Rooni committed
Commit 0afd6f0 · verified · 1 Parent(s): 7f10b7b

Update app.py

Files changed (1): app.py (+36 −87)
app.py CHANGED
@@ -1,97 +1,16 @@
  import os
- import random
- import uuid
- import json
+ from gradio_client import Client

  import gradio as gr
  import numpy as np
- from PIL import Image
- import spaces
- import torch
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

  # Use environment variables for flexibility
- MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
+ MODEL_ID = os.getenv("MODEL_ID", "KingNish/SDXL-Flash")
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
- USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
- ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
  BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # Allow generating multiple images at once

- # Determine device and load model outside of function for efficiency
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- pipe = StableDiffusionXLPipeline.from_pretrained(
-     MODEL_ID,
-     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-     use_safetensors=True,
-     add_watermarker=False,
- ).to(device)
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
-
- # Torch compile for potential speedup (experimental)
- if USE_TORCH_COMPILE:
-     pipe.compile()
-
- # CPU offloading for larger RAM capacity (experimental)
- if ENABLE_CPU_OFFLOAD:
-     pipe.enable_model_cpu_offload()
-
- MAX_SEED = np.iinfo(np.int32).max
-
- def save_image(img):
-     unique_name = str(uuid.uuid4()) + ".png"
-     img.save(unique_name)
-     return unique_name
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
+ client = Client(MODEL_ID)

- @spaces.GPU(duration=30, queue=False)
- def generate(
-     prompt: str,
-     negative_prompt: str = "",
-     use_negative_prompt: bool = False,
-     seed: int = 1,
-     width: int = 1024,
-     height: int = 1024,
-     guidance_scale: float = 3,
-     num_inference_steps: int = 30,
-     randomize_seed: bool = False,
-     use_resolution_binning: bool = True,
-     num_images: int = 1,  # Number of images to generate
-     progress=gr.Progress(track_tqdm=True),
- ):
-     seed = int(randomize_seed_fn(seed, randomize_seed))
-     generator = torch.Generator(device=device).manual_seed(seed)
-
-     # Improved options handling
-     options = {
-         "prompt": [prompt] * num_images,
-         "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
-         "width": width,
-         "height": height,
-         "guidance_scale": guidance_scale,
-         "num_inference_steps": num_inference_steps,
-         "generator": generator,
-         "output_type": "pil",
-     }
-
-     # Use resolution binning for faster generation with less VRAM usage
-     if use_resolution_binning:
-         options["use_resolution_binning"] = True
-
-     # Generate images potentially in batches
-     images = []
-     for i in range(0, num_images, BATCH_SIZE):
-         batch_options = options.copy()
-         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
-         if "negative_prompt" in batch_options:
-             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-         images.extend(pipe(**batch_options).images)
-
-     image_paths = [save_image(img) for img in images]
-     return image_paths, seed

  examples = [
      "a cat eating a piece of cheese",
@@ -122,7 +41,7 @@ with gr.Blocks(css=css) as demo:
                  container=False,
              )
              run_button = gr.Button("Run", scale=0)
-         result = gr.Gallery(label="Result", columns=1, show_label=False)
+         result = gr.Gallery(label="Result", columns=1, show_label=False)
      with gr.Accordion("Advanced options", open=False):
          num_images = gr.Slider(
              label="Number of Images",
@@ -144,7 +63,7 @@ with gr.Blocks(css=css) as demo:
          seed = gr.Slider(
              label="Seed",
              minimum=0,
-             maximum=MAX_SEED,
+             maximum=np.iinfo(np.int32).max,
              step=1,
              value=0,
          )
@@ -193,6 +112,36 @@ with gr.Blocks(css=css) as demo:
          api_name=False,
      )

+     def generate(
+         prompt,
+         negative_prompt,
+         use_negative_prompt,
+         seed,
+         width,
+         height,
+         guidance_scale,
+         num_inference_steps,
+         randomize_seed,
+         num_images,
+     ):
+         results = []
+         for _ in range(num_images):
+             response = client.predict(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt if use_negative_prompt else "",
+                 use_negative_prompt=use_negative_prompt,
+                 seed=seed,
+                 width=width,
+                 height=height,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 randomize_seed=randomize_seed,
+                 use_resolution_binning=True,
+                 api_name="/run",
+             )
+             results.append(response[0])  # Assuming response contains image path or URL
+         return results, seed
+
      gr.on(
          triggers=[
              prompt.submit,
@@ -217,4 +166,4 @@ with gr.Blocks(css=css) as demo:
      )

  if __name__ == "__main__":
-     demo.queue(max_size=20).launch()
+     demo.queue(max_size=20).launch()
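
The substance of this commit is a swap from local diffusers inference to remote inference over gradio_client. The sketch below shows that calling pattern in isolation; the Space name, the "/run" endpoint, and its parameters are taken from the diff above, while the Space's availability and the response layout (an image path or URL in response[0]) are assumptions, as the diff's own comment notes.

# Minimal sketch of the remote-inference pattern this commit adopts.
# Assumes the KingNish/SDXL-Flash Space is reachable and exposes the
# "/run" endpoint with the parameters shown in the diff above.
from gradio_client import Client

client = Client("KingNish/SDXL-Flash")
response = client.predict(
    prompt="a cat eating a piece of cheese",
    negative_prompt="",
    use_negative_prompt=False,
    seed=0,
    width=1024,
    height=1024,
    guidance_scale=3,
    num_inference_steps=30,
    randomize_seed=True,
    use_resolution_binning=True,
    api_name="/run",
)
print(response[0])  # assumed to be the path or URL of the generated image

The trade-off is the usual one for this pattern: the app no longer needs torch, diffusers, or a GPU of its own, but each generation depends on the upstream Space being up and waits in its queue, and num_images is produced by repeated single calls rather than local batching (the retained BATCH_SIZE setting is no longer used).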