File size: 4,261 Bytes
a324479
 
 
 
 
 
169ec0c
a324479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169ec0c
 
892096a
 
 
 
 
 
 
 
 
 
 
169ec0c
a324479
 
892096a
a324479
 
 
169ec0c
a324479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3804b82
 
 
a324479
3804b82
a324479
 
 
 
169ec0c
 
 
 
 
 
 
 
 
892096a
 
 
 
 
 
169ec0c
 
 
 
 
 
 
 
 
 
892096a
169ec0c
 
892096a
 
169ec0c
 
 
892096a
169ec0c
 
 
 
 
 
 
 
a324479
169ec0c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import jax
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
from diffusers.utils import load_image
import jax.numpy as jnp
import numpy as np


# Segment-Anything-conditioned ControlNet, loaded as Flax weights in float32.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "mfidabel/controlnet-segment-anything",
    dtype=jnp.float32,
)

# Stable Diffusion v1.5 base pipeline (its "flax" revision) wired to that ControlNet.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    revision="flax",
    dtype=jnp.float32,
)

# Store the ControlNet weights under params["controlnet"] next to the base
# model's weights, then replicate the whole tree for use with pmap'd calls.
params.update(controlnet=controlnet_params)
p_params = replicate(params)

# Markdown shown at the top of the page: a heading and an introductory blurb.
title = "# 🧨 ControlNet on Segment Anything 🤗"
description = """This is a demo on 🧨 ControlNet based on Meta's [Segment Anything Model](https://segment-anything.com/).

                Upload a Segment Anything Segmentation Map, write a prompt, and generate images 🤗 This demo is still Work in Progress, so don't expect it to work well for now !! 

                
                Test some of the examples below to give it a try ⬇️
              """

# Example rows, each ordered as [prompt, negative prompt, conditioning-map path].
examples = [
    ["a modern main room of a house", "low quality", "examples/condition_image_1.png"],
    ["new york buildings,  Vincent Van Gogh starry night ", "low quality, monochrome", "examples/condition_image_2.png"],
    ["contemporary living room,  high quality, 4k, realistic", "low quality, monochrome, low res", "examples/condition_image_3.png"],
]


def infer(prompts, negative_prompts, image, num_inference_steps=50, seed=4, num_samples=4):
    """Generate images conditioned on a Segment Anything segmentation map.

    Args:
        prompts: Text prompt, repeated once per generated sample.
        negative_prompts: Negative prompt, repeated once per generated sample.
        image: Conditioning map as an array, as delivered by the ``gr.Image``
            input (``None`` when nothing was uploaded).
        num_inference_steps: Number of denoising steps (coerced with ``int``,
            since slider values may arrive as floats).
        seed: Seed for ``jax.random.PRNGKey`` (coerced with ``int``).
        num_samples: Requested number of images. Inputs are split evenly over
            the local JAX devices, so this is rounded up to the next multiple
            of ``jax.local_device_count()`` (at least one image per device).

    Returns:
        A list of ``uint8`` image arrays, one per generated sample.

    Raises:
        gr.Error: If no conditioning image was provided.
    """
    if image is None:
        raise gr.Error("Please upload a segmentation map first.")

    # ``shard`` and ``replicate`` work per *local* device, so size everything
    # (batch and PRNG keys) by the local device count.
    num_devices = jax.local_device_count()
    # ``shard`` splits the leading batch axis evenly across devices, so the
    # batch must be a whole multiple of ``num_devices``. Round up instead of
    # taking ``max(num_devices, n)``, which fails e.g. for 3 samples on 2 devices.
    requested = max(1, int(num_samples))
    num_samples = ((requested + num_devices - 1) // num_devices) * num_devices

    num_inference_steps = int(num_inference_steps)
    rng = jax.random.PRNGKey(int(seed))
    # One PRNG key per device, matching the leading axis of the sharded inputs.
    p_rng = jax.random.split(rng, num_devices)

    # Normalise to 3-channel RGB; uploads may be RGBA or single-channel arrays,
    # which ``fromarray(..., mode="RGB")`` would misinterpret.
    image = Image.fromarray(image).convert("RGB")

    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompts] * num_samples))
    processed_image = shard(pipe.prepare_image_inputs([image] * num_samples))

    images = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=p_rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Merge the leading per-device axes into a single batch axis of images.
    images = images.reshape((num_samples,) + images.shape[-3:])

    # Assumes the pipeline returns floats in [0, 1]. Scale to 0-255, then round
    # and clip so the uint8 cast neither truncates (0.999 -> 254) nor wraps.
    return [
        np.clip(np.round(np.asarray(x) * 255), 0, 255).astype(np.uint8)
        for x in images
    ]

# Page layout: header text, conditioning image next to the result gallery,
# prompt inputs beside advanced sampling controls, cached examples, and the
# Generate button wiring. The app is queued and launched at import time.
with gr.Blocks(css="h1 { text-align: center }") as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    # Conditioning map on the left (twice as wide), generated gallery on the right.
    with gr.Row(variant="panel"):
        with gr.Column(scale=2):
            cond_img = gr.Image(label="Input").style(height=200)
        with gr.Column(scale=1):
            output = gr.Gallery(label="Generated images").style(
                height=200,
                rows=[2],
                columns=[1, 2],
                object_fit="contain",
            )

    # Text prompts on the left; sampling controls and the trigger on the right.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=1)
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=1)

        with gr.Column():
            with gr.Accordion("Advanced options", open=False):
                num_steps = gr.Slider(minimum=10, maximum=60, value=50, step=1, label="Steps")
                seed = gr.Slider(minimum=0, maximum=1024, value=4, step=1, label="Seed")
                num_samples = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Nº Samples")

            submit = gr.Button("Generate")
            # TODO: add a download button for the generated images.

    # Examples supply only prompt, negative prompt and image, so ``infer`` runs
    # them with its default steps/seed/sample count; their outputs are cached.
    example_inputs = [prompt, negative_prompt, cond_img]
    gr.Examples(
        examples=examples,
        inputs=example_inputs,
        outputs=output,
        fn=infer,
        cache_examples=True,
    )

    # The Generate button additionally forwards the advanced-option sliders.
    generate_inputs = example_inputs + [num_steps, seed, num_samples]
    submit.click(fn=infer, inputs=generate_inputs, outputs=output)

demo.queue()
demo.launch()