File size: 15,249 Bytes
b273838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import gradio as gr
from functools import partial
import torch
import spaces

import DDCM_blind_face_image_restoration
import latent_DDCM_CCFG
import latent_DDCM_compression
from latent_models import load_model
import os
# import transformers
# transformers.utils.move_cache()


# Normalise the SPACES_ZERO_GPU flag: when the environment reports it as the
# string "true", rewrite it to "1".
# NOTE(review): presumably the `spaces` package expects "1" here; `spaces` is
# already imported above this point, so confirm the override still takes effect.
if os.getenv("SPACES_ZERO_GPU") == "true":
    os.environ["SPACES_ZERO_GPU"] = "1"


# Models available to every tab, keyed by the "Image size" choice shown in the
# UI. Both are loaded once at import time on the CPU, with float16=True and
# compile=False; only element [0] of load_model's return value is kept.
# NOTE(review): the positional 1000 is presumably the number of diffusion
# timesteps — confirm against latent_models.load_model.
avail_models = {'512x512': load_model('stabilityai/stable-diffusion-2-1-base', 1000, float16=True, device=torch.device("cpu"), compile=False)[0],
                '768x768': load_model('stabilityai/stable-diffusion-2-1', 1000, float16=True, device=torch.device("cpu"), compile=False)[0]
               }

# Compression entry point with the pre-loaded models already bound, so UI
# callbacks only need to supply the per-request arguments.
compression_func = partial(latent_DDCM_compression.main, avail_models=avail_models)


def get_t_and_k_from_file_name(file_name):
    """Extract T, K and the model type from a bit-stream file name.

    Each field is read as the text between the first occurrence of its marker
    letter (``T``, ``K`` or ``M``) and the next ``-``, e.g.
    ``T1000-K256-M512x512-...``.

    Only the final path component is parsed. Uploaded files arrive as full
    temporary paths, and an upper-case ``T``, ``K`` or ``M`` in a parent
    directory (for instance the macOS temp dir ``/var/folders/.../T/``) would
    otherwise be mistaken for a field marker.

    Args:
        file_name: File name or full path of the bit-stream file.

    Returns:
        Tuple ``(T, K, model_type)`` where ``T`` and ``K`` are ints and
        ``model_type`` is the raw string found after ``M``.

    Raises:
        IndexError: If one of the markers is missing from the file name.
        ValueError: If the text after ``T`` or ``K`` is not an integer.
    """
    base_name = os.path.basename(file_name)
    T = int(base_name.split('T')[1].split('-')[0])
    K = int(base_name.split('K')[1].split('-')[0])
    model_type = base_name.split('M')[1].split('-')[0]
    return T, K, model_type


def ccfg(text_input, T, K, ccfg_scale, model_type, compressed_file_in=None):
    """Run compressed classifier-free-guidance generation with the shared models.

    Forwards all arguments to ``latent_DDCM_CCFG.main`` together with the
    module-level ``avail_models``, after capping ``ccfg_scale`` so that it
    never exceeds ``K``.

    Args:
        text_input: Text prompt (callers in this file pass ``None`` when
            decoding from a bit-stream).
        T: Number of diffusion timesteps.
        K: Size of each codebook.
        ccfg_scale: Requested sub-sampled codebook size; values above ``K``
            are replaced by ``K``.
        model_type: Key into ``avail_models`` (``'512x512'`` or ``'768x768'``).
        compressed_file_in: Optional bit-stream file to decode from.

    Returns:
        Whatever ``latent_DDCM_CCFG.main`` returns.
    """
    effective_scale = min(ccfg_scale, K)
    return latent_DDCM_CCFG.main(
        text_input,
        T,
        K,
        effective_scale,
        model_type,
        compressed_file_in,
        avail_models=avail_models,
    )


@spaces.GPU
def decompress_given_bitstream(bitstream, method):
    """Decode an uploaded bit-stream with the decoder that matches ``method``.

    T, K and the model type are recovered from the bit-stream's file name and
    passed to the selected decoder together with the bit-stream itself.

    Args:
        bitstream: Uploaded file object exposing a ``.name`` path, or ``None``
            when nothing was uploaded.
        method: One of ``'compression'``, ``'blind'`` or ``'ccfg'``.

    Returns:
        Whatever the selected decoder returns.

    Raises:
        gr.Error: If no bit-stream file was provided.
        NotImplementedError: If ``method`` is not one of the supported names.
    """
    if bitstream is None:
        # gr.Error is an exception class: it is only shown to the user when it
        # is raised. Merely constructing it would fall through to
        # ``bitstream.name`` below and fail with an AttributeError instead.
        raise gr.Error("Please provide a bit-stream file when performing decompression")
    file_name = bitstream.name
    T, K, model_type = get_t_and_k_from_file_name(file_name)
    if method == 'compression':
        return compression_func(None, T, K, model_type, bitstream)
    elif method == 'blind':
        # NOTE(review): 'NIQE', 1 and True are fixed placeholder arguments,
        # presumably ignored when decoding from a bit-stream — confirm against
        # DDCM_blind_face_image_restoration.inference.
        return DDCM_blind_face_image_restoration.inference(None, T, K, 'NIQE', 1, True, bitstream)
    elif method == 'ccfg':
        # No guidance scale is known at decode time, so -1 is passed.
        return ccfg(None, T, K, -1, model_type, bitstream)
    else:
        raise NotImplementedError(f"Unsupported decompression method: {method!r}")


def validate_K(K):
    """Warn the user when the codebook size ``K`` is not a power of two.

    The value comes from a ``gr.Number`` field and may therefore arrive as an
    int, a float (e.g. ``2048.0``) or ``None`` when the field is empty. The
    bitwise power-of-two test only works on ints, so the value is converted
    first instead of letting ``&`` raise ``TypeError`` on a float.

    Args:
        K: Codebook size entered in the UI.

    Returns:
        None. A ``gr.Warning`` is emitted when ``K`` is not a positive
        integral power of two; an empty field is left for the downstream
        handler to deal with.
    """
    if K is None:
        return
    k_int = int(K)
    is_integral = k_int == K
    is_power_of_two = k_int >= 1 and (k_int & (k_int - 1)) == 0
    if not (is_integral and is_power_of_two):
        gr.Warning("For efficient bit usage, K should be a power of 2.")


# Decompression callbacks, one per tab: each entry is
# decompress_given_bitstream with its ``method`` argument already bound, so
# the UI only has to pass the uploaded bit-stream.
method_to_func = {
    method_name: partial(decompress_given_bitstream, method=method_name)
    for method_name in ('compression', 'blind', 'ccfg')
}

# Page heading, rendered with gr.HTML at the top of the app.
title = "<div style='text-align: center; font-size: 36px; font-weight: bold;'>Compressed Image Generation with Denoising Diffusion Codebook Models</div>"
# Introductory HTML shown under the title: authors, affiliation, links and a
# short description of DDCM. Markup is kept well-formed (every heading is
# closed by its own tag, line breaks use <br>).
intro = """
<h3 style="margin-bottom: 10px; text-align: center;">
    <a href="https://ohayonguy.github.io/">Guy Ohayon*</a>&nbsp;,&nbsp;
    <a href="https://hilamanor.github.io/">Hila Manor*</a>&nbsp;,&nbsp;
    <a href="https://tomer.net.technion.ac.il/">Tomer Michaeli</a>&nbsp;,&nbsp;
    <a href="https://elad.cs.technion.ac.il/">Michael Elad</a>
</h3>
<p style="font-size: 12px; text-align: center; margin-bottom: 10px;">
    * Equal contribution
</p>
<h4 style="margin-bottom: 10px; text-align: center;">
    Technion - Israel Institute of Technology
</h4>
<h3 style="margin-bottom: 10px; text-align: center;">
    <a href="https://www.arxiv.org/abs/2502.01189/">[Paper]</a>&nbsp;|&nbsp;
    <a href="https://ddcm-2025.github.io/">[Project Page]</a>&nbsp;|&nbsp;
    <a href="https://github.com/DDCM-2025/ddcm-compressed-image-generation/">[Code]</a>
</h3>
<br><br>
Denoising Diffusion Codebook Models (DDCM) is a novel (and simple) generative approach based on any Denoising Diffusion Model (DDM), that is able to produce high-quality image samples along with their losslessly compressed bit-stream representations.
DDCM can easily be utilized for perceptual image compression, as well as for solving a variety of compressed conditional generation tasks such as text-conditional image generation and image restoration, where each generated sample is accompanied by a compressed bit-stream.
<br><br>
The tabs below correspond to demos of different practical applications. Open each tab to see the application's specific instructions.
<br><br>
<b>Note: The demos below rely on relatively old pre-trained diffusion models such as Stable Diffusion 2.1, simply for the purpose of demonstrating the capabilities of DDCM. Feel free to implement our DDCM-based methods using newer diffusion models to further improve performance.</b>
"""

# Footer Markdown rendered at the bottom of the page: citation, license and
# contact details. The section icons are real emoji, not mis-decoded bytes.
# NOTE(review): the contact addresses below read "[email protected]", which
# looks like an e-mail-obfuscation placeholder rather than real addresses —
# restore the intended addresses.
article = r"""
If you find our work useful, please ⭐ our <a href='https://github.com/DDCM-2025/ddcm-compressed-image-generation' target='_blank'>GitHub repository</a>. Thanks!

📝 **Citation**
```bibtex
@article{ohayon2025compressedimagegenerationdenoising,
      title={Compressed Image Generation with Denoising Diffusion Codebook Models}, 
      author={Guy Ohayon and Hila Manor and Tomer Michaeli and Michael Elad},
      year={2025},
      eprint={2502.01189},
      journal={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2502.01189}, 
}
```

📋 **License**
This project is released under the <a rel="license" href="https://github.com/DDCM-2025/ddcm-compressed-image-generation/blob/master/LICENSE">MIT license</a>.

📧 **Contact**
If you have any questions, please feel free to contact us at <b>[email protected]</b> (Guy Ohayon) and <b>[email protected]</b> (Hila Manor).
"""

# Extra CSS passed to gr.Blocks: enlarges and emboldens the tab-selector
# buttons.
custom_css = """
    .tabs button {
        font-size: 21px !important;
        font-weight: bold !important;
    }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.HTML(title)
    gr.HTML(intro)
    # gr.Markdown("# Compressed Image Generation with Denoising Diffusion Codebook Models")

    # Tab 1: compress an uploaded image into a bit-stream, and (lower section)
    # decode a previously saved bit-stream.
    with gr.Tab("Image Compression"):
        gr.Markdown(
            "- To change the bit rate, modify the number of diffusion timesteps (T) and/or the codebook sizes (K).")
        gr.Markdown("- The input image will be center-cropped and resized to the specified size (512x512 or 768x768).")
        # gr.Markdown("#### Notes:")
        # gr.Markdown('* Since our methods relies on Stable Diffusion, we resize the input image to 512512 pixels')

        with gr.Row():
            # Left column: input image and compression settings.
            with gr.Column(scale=2):
                input_image = gr.Image(label="Input image", scale=2, image_mode='RGB', type='pil')
                with gr.Group():
                    with gr.Row():
                        T = gr.Number(label="Diffusion timesteps (T)", minimum=50, maximum=1000, value=1000, scale=2)
                        K = gr.Number(label="Size of each codebook (K)", minimum=2, maximum=8192, value=2048, scale=3)
                    with gr.Row():
                        model_type = gr.Radio(["768x768", "512x512"], label="Image size", value="512x512")
                compress = gr.Button("Compress image")

            # Right column: reconstructed image and the downloadable bit-stream.
            with gr.Column(scale=3):
                decompressed_image = gr.Image(label="Decompressed image", scale=2)
                compressed_file_out = gr.File(label="Compressed bit-stream (output)", scale=0)

        # validate_K runs first (it has no outputs), then the compression
        # itself is chained after it with .then().
        compress.click(validate_K, inputs=[K]).then(compression_func, inputs=[input_image, T, K, model_type],
                                                    outputs=[decompressed_image, compressed_file_out])

        # Example inputs; each row is [image path, T, K, image size].
        gr.Examples([
            ["examples/compression/1.jpg", 1000, 256, '512x512'],
            ["examples/compression/2.jpg", 1000, 256, '512x512'],
            ["examples/compression/4.jpg", 1000, 256, '512x512'],
            ["examples/compression/7.jpg", 1000, 256, '512x512'],
            ["examples/compression/8.jpg", 1000, 256, '512x512'],
            ["examples/compression/13.jpg", 1000, 256, '512x512'],
            ["examples/compression/15.jpg", 1000, 256, '512x512'],
            ["examples/compression/17.jpg", 1000, 256, '512x512'],
            ["examples/compression/18.jpg", 1000, 256, '512x512'],
            ["examples/compression/19.jpg", 1000, 256, '512x512'],
            ["examples/compression/21.jpg", 1000, 256, '512x512'],
            ["examples/compression/22.jpg", 1000, 256, '512x512'],
            ["examples/compression/23.jpg", 1000, 256, '512x512'],
        ],
            inputs=[input_image, T, K, model_type],
            outputs=[decompressed_image, compressed_file_out],
            fn=compression_func,
            cache_examples='lazy')

        gr.Markdown("### Decompress a previously generated bit-stream")
        with gr.Row():
            with gr.Column(scale=2):
                bitstream = gr.File(label="Compressed bit-stream (input)", scale=0)
                decompress = gr.Button("Decompress image")

            with gr.Column(scale=3):
                # Rebinds `decompressed_image` to a new component; the click
                # and Examples wiring above already captured the earlier one.
                decompressed_image = gr.Image(label="Decompressed image (from uploaded bit-stream)", scale=2)

        # NOTE(review): compression_func is wired to two outputs above but to a
        # single Image here — confirm it returns a single image when called
        # with a bit-stream.
        decompress.click(method_to_func['compression'], inputs=bitstream, outputs=decompressed_image)

    # Tab 2: restore (and compress) degraded face images, and (lower section)
    # decode a previously saved bit-stream.
    with gr.Tab("Real-World Face Image Restoration"):
        gr.Markdown(  # "Restore any degraded face image. "
            "Please mark if your input face image is already aligned. "
            "If not, we will try to automatically detect, crop and align the faces, and raise an error if no faces are found. Expect better results if your input image is already aligned.")

        with gr.Row():
            # Left column: input image and restoration settings.
            with gr.Column(scale=2):
                with gr.Group():
                    input_image = gr.Image(label="Input image", scale=2, type='filepath')
                    aligned = gr.Checkbox(label='Input face image is aligned')
                with gr.Group():
                    with gr.Row():
                        T = gr.Number(label="Diffusion timesteps (T)", minimum=50, maximum=1000, value=1000)
                        K = gr.Number(label="Size of each codebook (K)", minimum=2, maximum=8192, value=2048)
                    iqa_metric = gr.Radio(['NIQE', 'TOPIQ', 'CLIP-IQA'], label='Perceptual quality measure to optimize',
                                          value='NIQE')
                    # The label ends with the Greek letter lambda; it must be
                    # the real character, not its mis-decoded byte sequence.
                    iqa_coef = gr.Number(
                        label="Perception-distortion tradeoff coefficient (λ)",
                        info="Higher -> better perceptual quality",
                        # label="Coefficient controlling the perception-distortion tradeoff (higher means better perceptual quality)",
                        minimum=0, maximum=1, value=1)
                restore = gr.Button("Restore and compress")

            # Right column: gallery of restored faces and their bit-streams.
            with gr.Column(scale=3):
                decompressed_image = gr.Gallery(label="Restored faces gallery", type="numpy", show_label=True,
                                                format="png")
                compressed_file_out = gr.File(label="Compressed bit-stream (output)", scale=0, file_count='multiple')

        # validate_K runs first (it has no outputs), then the restoration is
        # chained after it with .then().
        restore.click(validate_K, inputs=[K]).then(DDCM_blind_face_image_restoration.inference,
                                                   inputs=[input_image, T, K, iqa_metric, iqa_coef, aligned],
                                                   outputs=[decompressed_image, compressed_file_out])
        # Example inputs; each row is
        # [image path, T, K, quality measure, tradeoff coefficient, aligned].
        gr.Examples([
            ["examples/bfr/00000055.png", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/00000085.png", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/00000113.png", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/00000137.png", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/wider/0034.jpg", 1000, 4096, 'NIQE', 1, True],
            ["examples/bfr/webphoto/00042_00.jpg", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/lfw/Ana_Palacio_0001_00.jpg", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/01.png", 1000, 4096, 'NIQE', 0.1, False],
            ["examples/bfr/03.jpg", 1000, 4096, 'TOPIQ', 0.1, False],
        ],
            inputs=[input_image, T, K, iqa_metric, iqa_coef, aligned],
            outputs=[decompressed_image, compressed_file_out],
            fn=DDCM_blind_face_image_restoration.inference,
            cache_examples='lazy')

        gr.Markdown("### Decompress a previously generated bit-stream")
        with gr.Row():
            with gr.Column(scale=2):
                bitstream = gr.File(label="Compressed bit-stream (input)", scale=0)
                decompress = gr.Button("Decompress image")

            with gr.Column(scale=3):
                # Rebinds `decompressed_image` to a new component; the click
                # and Examples wiring above already captured the gallery.
                decompressed_image = gr.Image(label="Decompressed image (from uploaded bit-stream)", scale=2)

        # NOTE(review): inference() is wired to two outputs (gallery + files)
        # above but to a single Image here — confirm it returns a single image
        # when called with a bit-stream.
        decompress.click(method_to_func['blind'], inputs=bitstream, outputs=decompressed_image)

    # Tab 3: generate an image from a text prompt together with its
    # bit-stream, and (lower section) decode a previously saved bit-stream.
    with gr.Tab("Compressed Text-to-Image Generation"):
        gr.Markdown(
            "This application demonstrates the capabilities of our new *compressed* classifier-free guidance method, which *does not require the input condition for decompression*."
            "  \n"  # newline
            "Each image is generated along with its compressed bit-stream representation, and the input condition is implicitly encoded in the bit-stream.")
        # gr.Markdown("### Generate an image and its compressed bit-stream given an input text prompt")
        # gr.Markdown("#### Notes:")
        # gr.Markdown("* The size of the generated image is 512x512")

        with gr.Row():
            # Left column: prompt and generation settings.
            with gr.Column(scale=2):
                with gr.Group():
                    text_input = gr.Textbox(label="Input text prompt", scale=1, value="An image of a dog")
                    with gr.Row():
                        T = gr.Number(label="Diffusion timesteps (T)", minimum=50, maximum=1000, value=1000, scale=1)
                        K = gr.Number(label="Size of each codebook (K)", minimum=2, maximum=256, value=128, scale=1)
                    # Passed to ccfg() as ccfg_scale, which caps it at K.
                    K_tilde = gr.Number(label=r"Sub-sampled codebooks' sizes (K̃)", scale=1,
                                        info="Behaves like a guidance scale", minimum=2, maximum=256, value=32)
                    model_type = gr.Radio(["768x768", "512x512"], label="Image size", value="512x512")
                button = gr.Button("Generate and compress")

            # Right column: generated image and the downloadable bit-stream.
            with gr.Column(scale=3):
                decompressed_image = gr.Image(label="Generated image", scale=2)
                compressed_file_out = gr.File(label="Compressed bit-stream (output)", scale=0)

        # validate_K runs first (it has no outputs), then generation is
        # chained after it with .then().
        button.click(validate_K, inputs=[K]).then(ccfg, inputs=[text_input, T, K, K_tilde, model_type],
                                                  outputs=[decompressed_image, compressed_file_out])

        # Example inputs; each row is [prompt, T, K, K-tilde, image size].
        gr.Examples([
            ["An image of a dog", 1000, 64, 4, '512x512'],
            ["Rainbow over the mountains", 1000, 64, 4, '512x512'],
            ["A cat playing soccer", 1000, 64, 4, '512x512'],
        ],
            inputs=[text_input, T, K, K_tilde, model_type],
            outputs=[decompressed_image, compressed_file_out],
            fn=ccfg,
            cache_examples='lazy')
        gr.Markdown("### Decompress a previously generated bit-stream")
        with gr.Row():
            with gr.Column(scale=2):
                bitstream = gr.File(label="Compressed bit-stream (input)", scale=0)
                # Rebinds `button`; the generate click above is already wired.
                button = gr.Button("Decompress")
            with gr.Column(scale=3):
                # Rebinds `decompressed_image` to a new component; the click
                # and Examples wiring above already captured the earlier one.
                decompressed_image = gr.Image(label="Decompressed image (from uploaded bit-stream)", scale=2)
        # NOTE(review): ccfg is wired to two outputs above but to a single
        # Image here — confirm it returns a single image when called with a
        # bit-stream.
        button.click(method_to_func['ccfg'], inputs=bitstream, outputs=decompressed_image)

    gr.Markdown(article)

# Enable the request queue, then start serving the app.
# NOTE(review): state_session_capacity=500 presumably bounds how many
# sessions' state is kept in memory — confirm against the launch()
# documentation of the pinned Gradio version.
demo.queue()
demo.launch(state_session_capacity=500)