badayvedat commited on
Commit
cf12300
·
1 Parent(s): f20ed1f

feat: add local version of lcm model

Browse files
Files changed (7) hide show
  1. README.md +1 -1
  2. app.py +175 -0
  3. constants.py +205 -0
  4. gradio_examples.py +4 -0
  5. model.py +43 -0
  6. style.css +3 -0
  7. utils.py +15 -0
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📈
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.4.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.50.2
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from collections import deque
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import gradio as gr
7
+ from PIL import Image
8
+
9
+ from constants import DESCRIPTION, LOGO
10
+ from gradio_examples import EXAMPLES
11
+ from model import get_pipeline
12
+ from utils import replace_background
13
+
14
+ MAX_QUEUE_SIZE = 4
15
+
16
+ pipeline = get_pipeline()
17
+
18
+
19
+ @dataclass
20
+ class GenerationState:
21
+ prompts: deque
22
+ generations: deque
23
+
24
+
25
+ def get_initial_state() -> GenerationState:
26
+ return GenerationState(
27
+ prompts=deque(maxlen=MAX_QUEUE_SIZE),
28
+ generations=deque(maxlen=MAX_QUEUE_SIZE),
29
+ )
30
+
31
+
32
+ def load_initial_state(request: gr.Request) -> GenerationState:
33
+ print("Loading initial state for", request.client.host)
34
+ print("Total number of active threads", threading.active_count())
35
+
36
+ return get_initial_state()
37
+
38
+
39
+ async def put_to_queue(
40
+ image: Optional[Image.Image],
41
+ prompt: str,
42
+ seed: int,
43
+ strength: float,
44
+ state: GenerationState,
45
+ ):
46
+ prompts_queue = state.prompts
47
+
48
+ if prompt and image is not None:
49
+ prompts_queue.append((image, prompt, seed, strength))
50
+
51
+ return state
52
+
53
+
54
+ def inference(state: GenerationState) -> Image.Image:
55
+ prompts_queue = state.prompts
56
+ generations_queue = state.generations
57
+
58
+ if len(prompts_queue) == 0:
59
+ return state
60
+
61
+ image, prompt, seed, strength = prompts_queue.popleft()
62
+
63
+ original_image_size = image.size
64
+ image = replace_background(image.resize((512, 512)))
65
+
66
+ result = pipeline(
67
+ prompt=prompt,
68
+ image=image,
69
+ strength=strength,
70
+ seed=seed,
71
+ guidance_scale=1,
72
+ num_inference_steps=4,
73
+ )
74
+
75
+ output_image = result.images[0].resize(original_image_size)
76
+
77
+ generations_queue.append(output_image)
78
+
79
+ return state
80
+
81
+
82
+ def update_output_image(state: GenerationState):
83
+ image_update = gr.update()
84
+
85
+ generations_queue = state.generations
86
+
87
+ if len(generations_queue) > 0:
88
+ generated_image = generations_queue.popleft()
89
+ image_update = gr.update(value=generated_image)
90
+
91
+ return image_update, state
92
+
93
+
94
+ with gr.Blocks(css="style.css", title=f"Realtime Latent Consistency Model") as demo:
95
+ generation_state = gr.State(get_initial_state())
96
+
97
+ gr.HTML(f'<div style="width: 70px;">{LOGO}</div>')
98
+ gr.Markdown(DESCRIPTION)
99
+ with gr.Row(variant="default"):
100
+ input_image = gr.Image(
101
+ tool="color-sketch",
102
+ source="canvas",
103
+ label="Initial Image",
104
+ type="pil",
105
+ height=512,
106
+ width=512,
107
+ brush_radius=40.0,
108
+ )
109
+
110
+ output_image = gr.Image(
111
+ label="Generated Image",
112
+ type="pil",
113
+ interactive=False,
114
+ elem_id="output_image",
115
+ )
116
+ with gr.Row():
117
+ with gr.Column():
118
+ prompt_box = gr.Textbox(label="Prompt", value=EXAMPLES[0])
119
+
120
+ with gr.Accordion(label="Advanced Options", open=False):
121
+ with gr.Row():
122
+ with gr.Column():
123
+ strength = gr.Slider(
124
+ label="Strength",
125
+ minimum=0.1,
126
+ maximum=1.0,
127
+ step=0.05,
128
+ value=0.8,
129
+ info="""
130
+ Strength of the initial image that will be applied during inference.
131
+ """,
132
+ )
133
+ with gr.Column():
134
+ seed = gr.Slider(
135
+ label="Seed",
136
+ minimum=0,
137
+ maximum=2**31 - 1,
138
+ step=1,
139
+ randomize=True,
140
+ info="""
141
+ Seed for the random number generator.
142
+ """,
143
+ )
144
+
145
+ demo.load(
146
+ load_initial_state,
147
+ outputs=[generation_state],
148
+ )
149
+ demo.load(
150
+ inference,
151
+ inputs=[generation_state],
152
+ outputs=[generation_state],
153
+ every=0.1,
154
+ )
155
+ demo.load(
156
+ update_output_image,
157
+ inputs=[generation_state],
158
+ outputs=[output_image, generation_state],
159
+ every=0.1,
160
+ )
161
+ for event in [input_image.change, prompt_box.change, strength.change, seed.change]:
162
+ event(
163
+ put_to_queue,
164
+ [input_image, prompt_box, seed, strength, generation_state],
165
+ [generation_state],
166
+ show_progress=False,
167
+ queue=True,
168
+ )
169
+
170
+ gr.Markdown("## Example Prompts")
171
+ gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples")
172
+
173
+
174
+ if __name__ == "__main__":
175
+ demo.queue(concurrency_count=20, api_open=False).launch(max_threads=1024)
constants.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DESCRIPTION = """
2
+ # Real Time Latent Consistency Model
3
+ """
4
+
5
+
6
+ LOGO = """
7
+ <svg
8
+ width="100%"
9
+ height="100%"
10
+ viewBox="0 0 89 32"
11
+ fill="none"
12
+ xmlns="http://www.w3.org/2000/svg"
13
+ >
14
+ <path
15
+ d="M52.308 3.07812H57.8465V4.92428H56.0003V6.77043H54.1541V10.4627H57.8465V12.3089H54.1541V25.232H52.308V27.0781H46.7695V25.232H48.6157V12.3089H46.7695V10.4627H48.6157V6.77043H50.4618V4.92428H52.308V3.07812Z"
16
+ fill="currentColor"
17
+ ></path>
18
+ <path
19
+ d="M79.3849 23.3858H81.2311V25.232H83.0772V27.0781H88.6157V25.232H86.7695V23.3858H84.9234V4.92428H79.3849V23.3858Z"
20
+ fill="currentColor"
21
+ ></path>
22
+ <path
23
+ d="M57.8465 14.155H59.6926V12.3089H61.5388V10.4627H70.7695V12.3089H74.4618V23.3858H76.308V25.232H78.1541V27.0781H72.6157V25.232H70.7695V23.3858H68.9234V14.155H67.0772V12.3089H65.2311V14.155H63.3849V23.3858H65.2311V25.232H67.0772V27.0781H61.5388V25.232H59.6926V23.3858H57.8465V14.155Z"
24
+ fill="currentColor"
25
+ ></path>
26
+ <path
27
+ d="M67.0772 25.232V23.3858H68.9234V25.232H67.0772Z"
28
+ fill="currentColor"
29
+ ></path>
30
+ <rect
31
+ opacity="0.22"
32
+ x="7.38477"
33
+ y="29.5391"
34
+ width="2.46154"
35
+ height="2.46154"
36
+ fill="#5F4CD9"
37
+ ></rect>
38
+ <rect
39
+ opacity="0.85"
40
+ x="2.46094"
41
+ y="19.6914"
42
+ width="12.3077"
43
+ height="2.46154"
44
+ fill="#5F4CD9"
45
+ ></rect>
46
+ <rect
47
+ x="4.92383"
48
+ y="17.2305"
49
+ width="9.84615"
50
+ height="2.46154"
51
+ fill="#5F4CD9"
52
+ ></rect>
53
+ <rect
54
+ opacity="0.4"
55
+ x="7.38477"
56
+ y="27.0781"
57
+ width="4.92308"
58
+ height="2.46154"
59
+ fill="#5F4CD9"
60
+ ></rect>
61
+ <rect
62
+ opacity="0.7"
63
+ y="22.1562"
64
+ width="14.7692"
65
+ height="2.46154"
66
+ fill="#5F4CD9"
67
+ ></rect>
68
+ <rect
69
+ opacity="0.5"
70
+ x="7.38477"
71
+ y="24.6133"
72
+ width="7.38462"
73
+ height="2.46154"
74
+ fill="#5F4CD9"
75
+ ></rect>
76
+ <rect
77
+ opacity="0.22"
78
+ x="7.38477"
79
+ y="12.3086"
80
+ width="2.46154"
81
+ height="2.46154"
82
+ fill="#5F4CD9"
83
+ ></rect>
84
+ <rect
85
+ opacity="0.85"
86
+ x="2.46094"
87
+ y="2.46094"
88
+ width="12.3077"
89
+ height="2.46154"
90
+ fill="#5F4CD9"
91
+ ></rect>
92
+ <rect x="4.92383" width="9.84615" height="2.46154" fill="#5F4CD9"></rect>
93
+ <rect
94
+ opacity="0.4"
95
+ x="7.38477"
96
+ y="9.84375"
97
+ width="4.92308"
98
+ height="2.46154"
99
+ fill="#5F4CD9"
100
+ ></rect>
101
+ <rect
102
+ opacity="0.7"
103
+ y="4.92188"
104
+ width="14.7692"
105
+ height="2.46154"
106
+ fill="#5F4CD9"
107
+ ></rect>
108
+ <rect
109
+ opacity="0.5"
110
+ x="7.38477"
111
+ y="7.38281"
112
+ width="7.38462"
113
+ height="2.46154"
114
+ fill="#5F4CD9"
115
+ ></rect>
116
+ <rect
117
+ opacity="0.22"
118
+ x="24.6152"
119
+ y="29.5391"
120
+ width="2.46154"
121
+ height="2.46154"
122
+ fill="#5F4CD9"
123
+ ></rect>
124
+ <rect
125
+ opacity="0.85"
126
+ x="19.6914"
127
+ y="19.6914"
128
+ width="12.3077"
129
+ height="2.46154"
130
+ fill="#5F4CD9"
131
+ ></rect>
132
+ <rect
133
+ x="22.1543"
134
+ y="17.2305"
135
+ width="9.84615"
136
+ height="2.46154"
137
+ fill="#5F4CD9"
138
+ ></rect>
139
+ <rect
140
+ opacity="0.4"
141
+ x="24.6152"
142
+ y="27.0781"
143
+ width="4.92308"
144
+ height="2.46154"
145
+ fill="#5F4CD9"
146
+ ></rect>
147
+ <rect
148
+ opacity="0.7"
149
+ x="17.2305"
150
+ y="22.1562"
151
+ width="14.7692"
152
+ height="2.46154"
153
+ fill="#5F4CD9"
154
+ ></rect>
155
+ <rect
156
+ opacity="0.5"
157
+ x="24.6152"
158
+ y="24.6133"
159
+ width="7.38462"
160
+ height="2.46154"
161
+ fill="#5F4CD9"
162
+ ></rect>
163
+ <rect
164
+ opacity="0.22"
165
+ x="24.6152"
166
+ y="12.3086"
167
+ width="2.46154"
168
+ height="2.46154"
169
+ fill="#5F4CD9"
170
+ ></rect>
171
+ <rect
172
+ opacity="0.85"
173
+ x="19.6914"
174
+ y="2.46094"
175
+ width="12.3077"
176
+ height="2.46154"
177
+ fill="#5F4CD9"
178
+ ></rect>
179
+ <rect x="22.1543" width="9.84615" height="2.46154" fill="#5F4CD9"></rect>
180
+ <rect
181
+ opacity="0.4"
182
+ x="24.6152"
183
+ y="9.84375"
184
+ width="4.92308"
185
+ height="2.46154"
186
+ fill="#5F4CD9"
187
+ ></rect>
188
+ <rect
189
+ opacity="0.7"
190
+ x="17.2305"
191
+ y="4.92188"
192
+ width="14.7692"
193
+ height="2.46154"
194
+ fill="#5F4CD9"
195
+ ></rect>
196
+ <rect
197
+ opacity="0.5"
198
+ x="24.6152"
199
+ y="7.38281"
200
+ width="7.38462"
201
+ height="2.46154"
202
+ fill="#5F4CD9"
203
+ ></rect>
204
+ </svg>
205
+ """
gradio_examples.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ EXAMPLES = [
2
+ "a house on the water, oil painting",
3
+ "a sunset at a tropical beach with palm trees",
4
+ ]
model.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+
4
+ def get_pipeline():
5
+ import torch
6
+ from diffusers import AutoencoderTiny, AutoPipelineForImage2Image
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ pipe = AutoPipelineForImage2Image.from_pretrained(
11
+ "SimianLuo/LCM_Dreamshaper_v7",
12
+ use_safetensors=True,
13
+ )
14
+ pipe.vae = AutoencoderTiny.from_pretrained(
15
+ "madebyollin/taesd",
16
+ torch_dtype=torch.float16,
17
+ use_safetensors=True,
18
+ )
19
+ pipe = pipe.to(device, dtype=torch.float16)
20
+ pipe.unet.to(memory_format=torch.channels_last)
21
+ return pipe
22
+
23
+
24
+ def get_test_pipeline():
25
+ from PIL import Image
26
+ from dataclasses import dataclass
27
+ import random
28
+ import time
29
+
30
+ @dataclass
31
+ class Images:
32
+ images: list[Image.Image]
33
+
34
+ class Pipeline:
35
+ def __call__(self, *args: Any, **kwds: Any) -> Any:
36
+ time.sleep(0.5)
37
+ r = random.randint(0, 255)
38
+ g = random.randint(0, 255)
39
+ b = random.randint(0, 255)
40
+
41
+ return Images(images=[Image.new("RGB", (512, 512), color=(r, g, b))])
42
+
43
+ return Pipeline()
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+
4
+
5
+ def replace_background(image: Image.Image, new_background_color=(0, 255, 255)):
6
+ image_np = np.array(image)
7
+
8
+ white_threshold = 255 * 3
9
+ white_pixels = np.sum(image_np, axis=-1) == white_threshold
10
+
11
+ image_np[white_pixels] = new_background_color
12
+
13
+ result = Image.fromarray(image_np)
14
+
15
+ return result