import spaces
import gradio as gr
import torch
import os
import time

from diffusers import (
    DDPMScheduler,
    StableDiffusionXLImg2ImgPipeline,
    AutoencoderKL,
)

from diffusers.utils import load_image

# Install torch_tensorrt at startup; it is imported below as the torch.compile backend.
os.system("pip install torch_tensorrt==2.4.0")

BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"--------->Device: {device}")

# Use the fp16-safe SDXL VAE; the stock SDXL VAE can produce NaNs in float16.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# Load the SDXL img2img pipeline with fp16 weights and swap in a DDPM scheduler.
base_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    BASE_MODEL,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
base_pipe = base_pipe.to(device, silence_dtype_warnings=True)
base_pipe.scheduler = DDPMScheduler.from_pretrained(
    BASE_MODEL,
    subfolder="scheduler",
)

# Compile the UNet with the Torch-TensorRT backend. torch.compile is lazy, so
# the actual TensorRT compilation happens on the first forward pass (warm-up below).
import torch_tensorrt
import torch._dynamo

# Fall back to eager execution instead of raising if a subgraph fails to compile.
torch._dynamo.config.suppress_errors = True

backend = "torch_tensorrt"
print('Compiling model...')
base_pipe.unet = torch.compile(
    base_pipe.unet,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float32, torch.float16},
    },
    dynamic=False,
)

# Warm-up pass: run one generation at startup so TensorRT compilation happens
# before the first user request.
try:
    init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img.png")
    generated_image = base_pipe(
        image=init_image,
        prompt="A white cat",
        num_inference_steps=5,
    ).images[0]

    os.makedirs("/tmp/gradio", exist_ok=True)  # the directory may not exist yet at startup
    generated_image.save("/tmp/gradio/generated_image.png")
except Exception as e:
    print(f"Warm-up error: {e}")


def create_demo() -> gr.Blocks:

    @spaces.GPU(duration=30)
    def image_to_image(
        image,
        prompt: str,
        steps: int,
    ):
        # Record timing checkpoints so the UI can display the per-stage cost.
        run_task_time = 0
        time_cost_str = ''
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
        generated_image = base_pipe(
            image=image,
            prompt=prompt,
            num_inference_steps=steps,
        ).images[0]
        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
        # Return both values expected by the click handler's outputs below.
        return generated_image, time_cost_str
    
    def get_time_cost(run_task_time, time_cost_str):
        # Append the milliseconds elapsed since the previous checkpoint.
        now_time = int(time.time() * 1000)
        if run_task_time == 0:
            time_cost_str = 'start'
        else:
            if time_cost_str != '':
                time_cost_str += '-->'
            time_cost_str += f'{now_time - run_task_time}'
        run_task_time = now_time
        return run_task_time, time_cost_str

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", placeholder="Write a prompt here", lines=2, value="A beautiful sunset over the city")
            with gr.Column():
                steps = gr.Slider(minimum=1, maximum=100, value=5, step=1, label="Num Steps")
                g_btn = gr.Button("Generate")
                
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            with gr.Column():
                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
                time_cost = gr.Textbox(label="Time Cost", lines=1, interactive=False)
        
        g_btn.click(
            fn=image_to_image,
            inputs=[input_image, prompt, steps],
            outputs=[generated_image, time_cost],
        )

    return demo
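
# A minimal entry point sketch, assuming this module is the Space's app.py:
# build the Blocks UI and launch the Gradio server.
demo = create_demo()
demo.launch()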