Flourish commited on
Commit
244ab3b
·
verified ·
1 Parent(s): b5fb13c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -55
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import torch
2
  import gradio as gr
3
  import spaces
 
 
4
 
5
  from pipeline import ChatsSDXLPipeline
6
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -11,6 +13,7 @@ from PIL import Image
11
  logging.set_verbosity_error()
12
 
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
14
 
15
  feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
16
  safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
@@ -20,67 +23,101 @@ pipe = ChatsSDXLPipeline.from_pretrained(
20
  "AIDC-AI/CHATS",
21
  safety_checker=safety_checker,
22
  feature_extractor=feature_extractor,
23
- torch_dtype=torch.float16
24
  )
25
  pipe.to(DEVICE)
26
 
27
- @spaces.GPU
28
- def generate(prompt, steps=50, guidance_scale=7.5, height=768, width=512):
29
- output = pipe(
30
- prompt=prompt,
31
- num_inference_steps=steps,
32
- guidance_scale=guidance_scale,
33
- height=height,
34
- width=width,
35
- seed=0
36
- )
37
- return output['images']
38
- # image = output['images'][0]
39
- # image = Image.fromarray(image)
40
- # return image
41
 
42
- with gr.Blocks(title="🔥 CHATS-SDXL Demo") as demo:
43
- gr.Markdown(
44
- "## CHATS-SDXL Text-to-Image Demo\n\n"
45
- "Enter your prompt and click **Generate Image**. All NSFW content will be automatically filtered."
 
46
  )
47
- with gr.Row():
48
- prompt_input = gr.Textbox(
49
- label="Prompt",
50
- placeholder="Enter your description here...",
51
- lines=2,
52
- )
53
- with gr.Row():
54
- steps_slider = gr.Slider(
55
- minimum=1, maximum=100, value=50, step=1,
56
- label="Inference Steps"
57
- )
58
- scale_slider = gr.Slider(
59
- minimum=1.0, maximum=14.0, value=5.0, step=0.1,
60
- label="Guidance Scale"
61
- )
62
- with gr.Row():
63
- height_slider = gr.Slider(
64
- minimum=64, maximum=2048, value=1024, step=64,
65
- label="Image Height"
66
- )
67
- width_slider = gr.Slider(
68
- minimum=64, maximum=2048, value=1024, step=64,
69
- label="Image Width"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  )
71
- generate_button = gr.Button("Generate Image")
72
- gallery = gr.Gallery(
73
- label="Generated Images",
74
- show_label=False,
75
- columns=2,
76
- elem_id="gallery"
77
- )
78
 
79
- generate_button.click(
80
- fn=generate,
81
- inputs=[prompt_input, steps_slider, scale_slider, height_slider, width_slider],
82
- outputs=[gallery],
 
83
  )
84
 
85
- if __name__ == "__main__":
86
- demo.launch()
 
1
  import torch
2
  import gradio as gr
3
  import spaces
4
+ import random
5
+ import numpy as np
6
 
7
  from pipeline import ChatsSDXLPipeline
8
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
13
  logging.set_verbosity_error()
14
 
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+ MAX_SEED = np.iinfo(np.int32).max
17
 
18
  feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
19
  safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
 
23
  "AIDC-AI/CHATS",
24
  safety_checker=safety_checker,
25
  feature_extractor=feature_extractor,
26
+ torch_dtype=torch.bfloat16
27
  )
28
  pipe.to(DEVICE)
29
 
30
+ @spaces.GPU(duration=75)
31
+ def generate(prompt, seed, randomize_seed=False, steps=50, guidance_scale=7.5):
32
+ if randomize_seed:
33
+ seed = random.randint(0, MAX_SEED)
 
 
 
 
 
 
 
 
 
 
34
 
35
+ output = pipe(
36
+ prompt=prompt,
37
+ num_inference_steps=steps,
38
+ guidance_scale=guidance_scale,
39
+ seed=seed
40
  )
41
+ return output['images']
42
+
43
+ examples = [
44
+ "The image is a digital art headshot of an owlfolk character with high detail and dramatic lighting",
45
+ "Solar punk vehicle in a bustling city",
46
+ "An elderly woman poses for a high fashion photoshoot in colorful, patterned clothes with a cyberpunk 2077 vibe",
47
+ ]
48
+
49
+ css="""
50
+ #col-container {
51
+ margin: 0 auto;
52
+ max-width: 520px;
53
+ }
54
+ """
55
+
56
+ with gr.Blocks(css=css) as demo:
57
+
58
+ with gr.Column(elem_id="col-container"):
59
+ gr.Markdown(f"""# CHATS-SDXL
60
+ SDXL diffusion models finetuned using preference optimization framework CHATS. [[paper] (https://arxiv.org/pdf/2502.12579)] [[code](https://github.com/AIDC-AI/CHATS)] [[model](https://huggingface.co/AIDC-AI/CHATS)]
61
+ """)
62
+
63
+ with gr.Row():
64
+
65
+ prompt = gr.Text(
66
+ label="Prompt",
67
+ show_label=False,
68
+ max_lines=1,
69
+ placeholder="Enter your prompt here",
70
+ container=False,
71
+ )
72
+
73
+ run_button = gr.Button("Run", scale=0)
74
+
75
+ result = gr.Image(label="Result", show_label=False)
76
+
77
+ with gr.Accordion("Advanced Settings", open=False):
78
+
79
+ seed = gr.Slider(
80
+ label="Seed",
81
+ minimum=0,
82
+ maximum=MAX_SEED,
83
+ step=1,
84
+ value=0,
85
+ )
86
+
87
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
88
+
89
+ with gr.Row():
90
+
91
+ guidance_scale = gr.Slider(
92
+ label="Guidance Scale",
93
+ minimum=1,
94
+ maximum=14,
95
+ step=0.1,
96
+ value=5.0,
97
+ )
98
+
99
+ num_inference_steps = gr.Slider(
100
+ label="Number of inference steps",
101
+ minimum=1,
102
+ maximum=100,
103
+ step=1,
104
+ value=50,
105
+ )
106
+
107
+ gr.Examples(
108
+ examples = examples,
109
+ fn = generate,
110
+ inputs = [prompt],
111
+ outputs = [result],
112
+ cache_examples="lazy"
113
  )
 
 
 
 
 
 
 
114
 
115
+ gr.on(
116
+ triggers=[run_button.click, prompt.submit],
117
+ fn = generate,
118
+ inputs = [prompt, seed, randomize_seed, num_inference_steps, guidance_scale],
119
+ outputs = [result]
120
  )
121
 
122
+ if __name__ == '__main__':
123
+ demo.launch()