import torch
import gradio as gr
import spaces
import random
import numpy as np

from pipeline import ChatsSDXLPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
from diffusers.utils import logging

logging.set_verbosity_error()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SEED = np.iinfo(np.int32).max  # upper bound for seed values (int32 max)

# Safety-checker components (CLIP feature extractor + SD safety checker),
# used to filter unsafe outputs.
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

# Load CHATS-SDXL pipeline
pipe = ChatsSDXLPipeline.from_pretrained(
    "AIDC-AI/CHATS",
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
    torch_dtype=torch.bfloat16,
)
pipe.to(DEVICE)
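# Note: bfloat16 assumes GPU execution. If CUDA is unavailable and DEVICE falls
# back to "cpu", loading in float32 is generally safer, e.g. (an assumption,
# not part of the original app):
#   pipe = ChatsSDXLPipeline.from_pretrained("AIDC-AI/CHATS", torch_dtype=torch.float32)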

# Request a ZeroGPU worker for up to 75 seconds per call (Hugging Face Spaces).
@spaces.GPU(duration=75)
def generate(prompt, seed=0, randomize_seed=False, steps=50, guidance_scale=5.0):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    print(f"inference with prompt: {prompt}, seed: {seed}, steps: {steps}, cfg: {guidance_scale}")
    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        seed=seed,
    )
    return output["images"][0]
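# A minimal sketch of calling generate() directly, e.g. for a smoke test
# outside the Gradio UI; the output file name is illustrative only:
#   img = generate("Solar punk vehicle in a bustling city", seed=42, steps=30)
#   img.save("sample.png")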

examples = [
    "Solar punk vehicle in a bustling city",
    "An anthropomorphic cat riding a Harley Davidson in Arizona with sunglasses and a leather jacket",
    "An elderly woman poses for a high fashion photoshoot in colorful, patterned clothes with a cyberpunk 2077 vibe",
]

css="""
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

with gr.Blocks(css=css) as demo:
    
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""# CHATS-SDXL
SDXL diffusion models finetuned using preference optimization framework CHATS. [[paper](https://arxiv.org/pdf/2502.12579)] [[code](https://github.com/AIDC-AI/CHATS)] [[model](https://huggingface.co/AIDC-AI/CHATS)]
        """)
        
        with gr.Row():
            
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt here",
                container=False,
            )
            
            run_button = gr.Button("Run", scale=0)
        
        result = gr.Image(label="Result", show_label=False)
        
        with gr.Accordion("Advanced Settings", open=False):
            
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                        
            with gr.Row():

                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=14,
                    step=0.1,
                    value=5.0,
                )
  
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
        
        gr.Examples(
            examples=examples,
            fn=generate,
            inputs=[prompt],
            outputs=[result],
            cache_examples="lazy",
        )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate,
        inputs=[prompt, seed, randomize_seed, num_inference_steps, guidance_scale],
        outputs=[result],
    )

if __name__ == '__main__':
    demo.launch()