# app.py -- Hugging Face Space source, last updated by merve (HF Staff)
# in commit 6cbd07b ("Update app.py"); file size 1.04 kB.
from huggingface_hub import from_pretrained_keras
from keras_cv import models
import gradio as gr
from tensorflow import keras
# Select the global "mixed_float16" dtype policy. This is done first so the
# policy is already in place when the models below are constructed.
keras.mixed_precision.set_global_policy("mixed_float16")

# Prepare the model: build a square Stable Diffusion pipeline whose width and
# height are both `resolution`, then replace its diffusion model with the
# DreamBooth fine-tuned one loaded from the Hugging Face Hub.
resolution = 512
sd_dreambooth_model = models.StableDiffusion(
    img_width=resolution,
    img_height=resolution,
    jit_compile=True,
)
db_diffusion_model = from_pretrained_keras("merve/dreambooth_diffusion_model")
sd_dreambooth_model._diffusion_model = db_diffusion_model
# generate images
def infer(prompt):
    """Generate a batch of nine images for ``prompt``.

    Args:
        prompt: Text prompt forwarded unchanged to the DreamBooth
            Stable Diffusion pipeline.

    Returns:
        The result of ``sd_dreambooth_model.text_to_image`` for a batch
        size of 9, returned as-is without post-processing.
    """
    return sd_dreambooth_model.text_to_image(prompt, batch_size=9)
# Gallery component that receives the images returned by `infer`; the `grid`
# style option is passed through to Gradio unchanged.
output = gr.Gallery(label="Outputs").style(grid=(3, 3))

# customize interface
title = "Dreambooth Demo on Dog Images"
description = (
    "This is a dreambooth model fine-tuned on dog images. "
    "To try it, input the concept with {sks dog}."
)
examples = [["sks dog in space"]]

# Wire a single text input to the gallery output and start serving the app.
# cache_examples=True is forwarded to Gradio so the example prompt's output
# is cached rather than regenerated on every click.
gr.Interface(
    fn=infer,
    inputs=["text"],
    outputs=[output],
    title=title,
    description=description,
    examples=examples,
    cache_examples=True,
).launch()