# NOTE: this file was captured from a Hugging Face Spaces file viewer.
# Space status at capture time: "Runtime error". File size: 3,959 bytes.
# (Per-line commit hashes and the line-number gutter from the viewer were
# page residue, not part of the program, and have been dropped.)
import gradio as gr
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
from diffusers import FlaxScoreSdeVeScheduler, FlaxDPMSolverMultistepScheduler
import torch
torch.backends.cuda.matmul.allow_tf32 = True
import torchvision
import torchvision.transforms as T
from flax.jax_utils import replicate
from flax.training.common_utils import shard
#from torchvision.transforms import v2 as T2
import cv2
import PIL
from PIL import Image
import numpy as np
import jax
import torchvision.transforms.functional as F
# Target (height, width) that every conditioning image is cropped/padded to.
output_res = (768, 768)

# Preprocessing applied to the user-supplied conditioning image: take a
# random crop of `output_res`, symmetric-padding images that are too small,
# then convert to a tensor. A ScaleJitter step and a Normalize([0.5], [0.5])
# step are left disabled.
conditioning_image_transforms = T.Compose(
    [
        T.RandomCrop(
            output_res,
            padding_mode="symmetric",
            pad_if_needed=True,
        ),
        T.ToTensor(),
    ]
)
# Load the ControlNet first: the pipeline loader below receives the module
# itself, while its weights come back separately as `cnet_params`.
cnet, cnet_params = FlaxControlNetModel.from_pretrained(
    "./models/catcon-controlnet-wd",
    from_flax=True,
    dtype=jnp.bfloat16,
)

# Build the ControlNet pipeline around the local base checkpoint; `params`
# holds the pipeline weights returned by the loader.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "./models/wd-1-5-b2-flax",
    dtype=jnp.bfloat16,
    revision="flax",
    controlnet=cnet,
)

# Left disabled: swapping in FlaxDPMSolverMultistepScheduler (loaded from the
# checkpoint's "scheduler" subfolder, or built from pipe.scheduler.config),
# pipe.enable_model_cpu_offload() and
# pipe.enable_xformers_memory_efficient_attention().
def get_random(seed):
    """Return the JAX PRNG key that ``jax.random.PRNGKey`` builds from ``seed``."""
    key = jax.random.PRNGKey(seed)
    return key
# inference function takes prompt, negative prompt and image
def infer(prompt, negative_prompt, image):
    """Generate images for ``prompt`` conditioned on an example ``image``.

    Args:
        prompt: Text prompt, tokenized with ``pipe.prepare_text_inputs``.
        negative_prompt: Negative text prompt, tokenized the same way.
        image: Conditioning image as an array accepted by
            ``PIL.Image.fromarray`` (presumably the array produced by the
            ``gr.Image()`` input -- confirm against the UI configuration).

    Returns:
        The list of PIL images produced by ``pipe.numpy_to_pil``; one image
        is generated per JAX device.

    Raises:
        ValueError: If no conditioning image was supplied.
    """
    # Fail early with a readable message instead of an AttributeError raised
    # from inside PIL when the image input was left empty.
    if image is None:
        raise ValueError("A conditioning image is required.")

    # The ControlNet weights are loaded separately from the pipeline weights,
    # so they are merged into the parameter tree handed to the pipeline.
    params["controlnet"] = cnet_params

    # shard() splits the leading batch axis across devices, so the batch must
    # hold one sample per device. A hard-coded batch of 1 only works on a
    # single-device host; on one device this is still exactly 1.
    num_devices = jax.device_count()
    num_samples = num_devices

    # Crop/pad the input to the model resolution, then go back to PIL because
    # prepare_image_inputs is fed PIL images here.
    source_image = Image.fromarray(image)
    cond_tensor = conditioning_image_transforms(source_image)
    cond_image = T.ToPILImage()(cond_tensor)

    cond_img_in = shard(pipe.prepare_image_inputs([cond_image] * num_samples))
    prompt_in = shard(pipe.prepare_text_inputs([prompt] * num_samples))
    n_prompt_in = shard(
        pipe.prepare_text_inputs([negative_prompt] * num_samples)
    )

    # Fixed seed (0), split into one key per device.
    rng = jax.random.split(get_random(0), num_devices)
    p_params = replicate(params)

    output = pipe(
        prompt_ids=prompt_in,
        image=cond_img_in,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=70,
        neg_prompt_ids=n_prompt_in,
        jit=True,
    ).images

    # Collapse the leading device/batch axes back into one flat batch axis
    # before converting to PIL.
    flat_images = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(flat_images))
# Build and serve the demo UI: two single-line text boxes (prompt, negative
# prompt) and an image input feed `infer`; the images it returns are shown
# in a gallery. Each example row is [prompt, negative prompt, image path].
gr.Interface(
    infer,
    inputs=[
        gr.Textbox(
            label="Enter prompt",
            max_lines=1,
            placeholder="1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck",
        ),
        gr.Textbox(
            label="Enter negative prompt",
            max_lines=1,
            placeholder="low quality",
        ),
        gr.Image(),
    ],
    # NOTE(review): `.style(...)` appears to be deprecated in later Gradio
    # releases in favour of constructor keyword arguments -- confirm against
    # the Gradio version pinned for this Space before changing it.
    outputs=gr.Gallery().style(grid=[2], height="auto"),
    title="Generate controlled outputs with Categorical Conditioning on Waifu Diffusion 1.5 beta 2.",
    description="This Space uses image examples as style conditioning.",
    examples=[
        ["1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck", "realistic, real life", "wikipe_cond_1.png"],
        ["1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck", "realistic, real life", "wikipe_cond_2.png"],
        ["1girl, green hair, sweater, looking at viewer, upper body, beanie, outdoors, watercolor, night, turtleneck", "realistic, real life", "wikipe_cond_3.png"],
    ],
    # `allow_flagging` is documented as a string option; "never" is the
    # string form of the boolean False used previously.
    allow_flagging="never",
    # Queueing is enabled through .queue() rather than the deprecated
    # `enable_queue` keyword of launch().
).queue().launch()
|