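# Gradio demo comparing three one-step SDXL text-to-image models on the same prompt:
# SDXL Turbo runs locally with diffusers, while SDXL Lightning and Hyper-SDXL are
# generated remotely through their Hugging Face Spaces via gradio_client.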
import gradio as gr
from gradio_client import Client
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
import torch
import concurrent.futures
import spaces

# Remote Spaces clients for SDXL Lightning and Hyper-SDXL
client_lightning = Client("AP123/SDXL-Lightning")
client_hyper = Client("ByteDance/Hyper-SDXL-1Step-T2I")

# fp16-safe SDXL VAE to avoid numerical issues when decoding in half precision
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

### SDXL Turbo #### 
pipe_turbo = StableDiffusionXLPipeline.from_pretrained("stabilityai/sdxl-turbo",
                                                       vae=vae,
                                                       torch_dtype=torch.float16,
                                                       variant="fp16"
                                                      )
pipe_turbo.to("cuda")


def get_lighting_result(prompt):
    # Query the SDXL-Lightning Space remotely through gradio_client
    result_lighting = client_lightning.predict(
        prompt,  # Your prompt
        "1-Step",  # Number of inference steps
        api_name="/generate_image"
    )
    return result_lighting

def get_hyper_result(prompt):
    # Query the Hyper-SDXL one-step Space remotely through gradio_client
    result_hyper = client_hyper.predict(
        num_images=1,
        height=1024,
        width=1024,
        prompt=prompt,
        seed=3413,
        api_name="/process_image"
    )
    return result_hyper

@spaces.GPU
def get_turbo_result(prompt):
    # Local SDXL Turbo inference: a single denoising step with guidance disabled
    image_turbo = pipe_turbo(prompt=prompt, num_inference_steps=1, guidance_scale=0).images[0]
    return image_turbo

def run_comparison(prompt):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Submit tasks to the executor
        future_lighting = executor.submit(get_lighting_result, prompt)
        future_hyper = executor.submit(get_hyper_result, prompt)
        future_turbo = executor.submit(get_turbo_result, prompt)

        # Wait for all futures to complete
        concurrent.futures.wait(
            [future_lighting, future_hyper, future_turbo],
            return_when=concurrent.futures.ALL_COMPLETED
        )

        # Extract results from futures
        result_lighting = future_lighting.result()
        result_hyper = future_hyper.result()
        image_turbo = future_turbo.result()
        print(result_lighting)
        print(result_hyper)
        return image_turbo, result_lighting, result_hyper

# Example usage
prompt = "Enter your prompt here"
image_turbo, result_lighting, result_hyper = run_comparison(prompt)

css = '''
.gradio-container{max-width: 768px !important}
'''
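# UI: a prompt box, a run button, and the three results displayed side by side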
with gr.Blocks(css=css) as demo:
    prompt = gr.Textbox(label="Prompt")
    run = gr.Button("Run")
    with gr.Row():
        image_turbo = gr.Image(label="SDXL Turbo")
        image_lightning = gr.Image(label="SDXL Lightning")
        image_hyper = gr.Image(label="Hyper SDXL")
    
    run.click(fn=run_comparison, inputs=prompt, outputs=[image_turbo, image_lightning, image_hyper])

demo.launch()