radames commited on
Commit
9f0c5bb
·
1 Parent(s): 95b0167
Files changed (1) hide show
  1. server/pipelines/img2imgSDXS512.py +175 -0
server/pipelines/img2imgSDXS512.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoPipelineForImage2Image, AutoencoderTiny
2
+ from compel import Compel
3
+ import torch
4
+
5
+ try:
6
+ import intel_extension_for_pytorch as ipex # type: ignore
7
+ except:
8
+ pass
9
+
10
+ import psutil
11
+ from config import Args
12
+ from pydantic import BaseModel, Field
13
+ from PIL import Image
14
+ import math
15
+
16
+ base_model = "IDKiro/sdxs-512-0.9"
17
+ taesd_model = "madebyollin/taesd"
18
+
19
+ default_prompt = "Portrait of The Terminator with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
20
+ page_content = """
21
+ <h1 class="text-3xl font-bold">Real-Time Latent SDXS</h1>
22
+ <h3 class="text-xl font-bold">Image-to-Image SDXS</h3>
23
+ <p class="text-sm">
24
+ This demo showcases
25
+ <a
26
+ href="https://huggingface.co/blog/lcm_lora"
27
+ target="_blank"
28
+ class="text-blue-500 underline hover:no-underline">LCM</a>
29
+ Image to Image pipeline using
30
+ <a
31
+ href="https://huggingface.co/docs/diffusers/main/en/using-diffusers/lcm#performing-inference-with-lcm"
32
+ target="_blank"
33
+ class="text-blue-500 underline hover:no-underline">Diffusers</a
34
+ > with a MJPEG stream server.
35
+ </p>
36
+ <p class="text-sm text-gray-500">
37
+ Change the prompt to generate different images, accepts <a
38
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
39
+ target="_blank"
40
+ class="text-blue-500 underline hover:no-underline">Compel</a
41
+ > syntax.
42
+ </p>
43
+ """
44
+
45
+
46
+ class Pipeline:
47
+ class Info(BaseModel):
48
+ name: str = "img2img"
49
+ title: str = "Image-to-Image SDXS"
50
+ description: str = "Generates an image from a text prompt"
51
+ input_mode: str = "image"
52
+ page_content: str = page_content
53
+
54
+ class InputParams(BaseModel):
55
+ prompt: str = Field(
56
+ default_prompt,
57
+ title="Prompt",
58
+ field="textarea",
59
+ id="prompt",
60
+ )
61
+ seed: int = Field(
62
+ 2159232, min=0, title="Seed", field="seed", hide=True, id="seed"
63
+ )
64
+ steps: int = Field(
65
+ 1, min=1, max=15, title="Steps", field="range", hide=True, id="steps"
66
+ )
67
+ width: int = Field(
68
+ 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
69
+ )
70
+ height: int = Field(
71
+ 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
72
+ )
73
+ guidance_scale: float = Field(
74
+ 0.0,
75
+ min=0,
76
+ max=20,
77
+ step=0.001,
78
+ title="Guidance Scale",
79
+ field="range",
80
+ hide=True,
81
+ id="guidance_scale",
82
+ )
83
+ strength: float = Field(
84
+ 0.5,
85
+ min=0.25,
86
+ max=1.0,
87
+ step=0.001,
88
+ title="Strength",
89
+ field="range",
90
+ hide=True,
91
+ id="strength",
92
+ )
93
+
94
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
95
+ if args.safety_checker:
96
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(base_model)
97
+ else:
98
+ self.pipe = AutoPipelineForImage2Image.from_pretrained(
99
+ base_model,
100
+ safety_checker=None,
101
+ )
102
+ if args.taesd:
103
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
104
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
105
+ ).to(device)
106
+
107
+ if args.sfast:
108
+ from sfast.compilers.stable_diffusion_pipeline_compiler import (
109
+ compile,
110
+ CompilationConfig,
111
+ )
112
+
113
+ config = CompilationConfig.Default()
114
+ config.enable_xformers = True
115
+ config.enable_triton = True
116
+ config.enable_cuda_graph = True
117
+ self.pipe = compile(self.pipe, config=config)
118
+
119
+ self.pipe.set_progress_bar_config(disable=True)
120
+ self.pipe.to(device=device, dtype=torch_dtype)
121
+ if device.type != "mps":
122
+ self.pipe.unet.to(memory_format=torch.channels_last)
123
+
124
+ if args.torch_compile:
125
+ print("Running torch compile")
126
+ self.pipe.unet = torch.compile(
127
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
128
+ )
129
+ self.pipe.vae = torch.compile(
130
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
131
+ )
132
+
133
+ self.pipe(
134
+ prompt="warmup",
135
+ image=[Image.new("RGB", (768, 768))],
136
+ )
137
+
138
+ if args.compel:
139
+ self.compel_proc = Compel(
140
+ tokenizer=self.pipe.tokenizer,
141
+ text_encoder=self.pipe.text_encoder,
142
+ truncate_long_prompts=False,
143
+ )
144
+
145
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
146
+ generator = torch.manual_seed(params.seed)
147
+ prompt_embeds = None
148
+ prompt = params.prompt
149
+ if hasattr(self, "compel_proc"):
150
+ prompt_embeds = self.compel_proc(params.prompt)
151
+ prompt = None
152
+
153
+ results = self.pipe(
154
+ image=params.image,
155
+ prompt=prompt,
156
+ prompt_embeds=prompt_embeds,
157
+ generator=generator,
158
+ strength=params.strength,
159
+ num_inference_steps=params.steps,
160
+ guidance_scale=params.guidance_scale,
161
+ width=params.width,
162
+ height=params.height,
163
+ output_type="pil",
164
+ )
165
+
166
+ nsfw_content_detected = (
167
+ results.nsfw_content_detected[0]
168
+ if "nsfw_content_detected" in results
169
+ else False
170
+ )
171
+ if nsfw_content_detected:
172
+ return None
173
+ result_image = results.images[0]
174
+
175
+ return result_image