awacke1 commited on
Commit
0159eab
1 Parent(s): 27866da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -39
app.py CHANGED
@@ -15,37 +15,61 @@ from PIL import Image
15
  from io import BytesIO
16
  from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
17
 
18
- # ... [previous imports and setup code remains unchanged]
 
 
 
19
 
20
- # New function to save prompt to history
21
- def save_prompt_to_history(prompt):
22
- with open("prompt_history.txt", "a") as f:
23
- timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
24
- f.write(f"{timestamp}: {prompt}\n")
 
 
 
 
 
25
 
26
- # Modified predict function
27
- def predict(prompt, guidance, steps, seed=1231231):
28
- generator = torch.manual_seed(seed)
29
- last_time = time.time()
30
- results = pipe(
31
- prompt=prompt,
32
- generator=generator,
33
- num_inference_steps=steps,
34
- guidance_scale=guidance,
35
- width=512,
36
- height=512,
37
- output_type="pil",
38
- )
39
- print(f"Pipe took {time.time() - last_time} seconds")
40
-
41
- # Save prompt to history
42
- save_prompt_to_history(prompt)
43
-
44
- # ... [rest of the function remains unchanged]
45
 
46
- return results.images[0] if len(results.images) > 0 else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # Modified save_all_images function
49
  def save_all_images(images):
50
  if len(images) == 0:
51
  return None, None
@@ -68,41 +92,205 @@ def save_all_images(images):
68
 
69
  return zip_filename, download_link
70
 
71
- # Function to read prompt history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  def read_prompt_history():
73
  if os.path.exists("prompt_history.txt"):
74
  with open("prompt_history.txt", "r") as f:
75
  return f.read()
76
  return "No prompts yet."
77
 
78
- # Modified Gradio interface
79
  with gr.Blocks(css=css) as demo:
80
  with gr.Column(elem_id="container"):
81
- # ... [previous UI components remain unchanged]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- # Add prompt history display
84
  with gr.Accordion("Prompt History", open=False):
85
  prompt_history = gr.Code(label="Prompt History", language="text", interactive=False)
86
 
87
- # ... [rest of the UI components]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Function to update prompt history display
90
- def update_prompt_history():
91
- return read_prompt_history()
92
 
93
- # Connect components
94
  generate_bt.click(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
 
95
  prompt.submit(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
96
-
97
- # Update prompt history when generating image or when accordion is opened
 
 
 
 
 
98
  generate_bt.click(fn=update_prompt_history, outputs=prompt_history)
99
  prompt.submit(fn=update_prompt_history, outputs=prompt_history)
100
 
101
- # Modify save_all_button click event
102
  save_all_button.click(
103
  fn=lambda: save_all_images([f for f in os.listdir() if f.lower().endswith((".png", ".jpg", ".jpeg"))]),
104
  outputs=[gr.File(), gr.HTML()]
105
  )
 
106
 
107
  demo.queue()
108
  demo.launch(allowed_paths=["/"])
 
15
  from io import BytesIO
16
  from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
17
 
18
+ try:
19
+ import intel_extension_for_pytorch as ipex
20
+ except:
21
+ pass
22
 
23
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
24
+ TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
25
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
26
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
27
+ xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
28
+ device = torch.device(
29
+ "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
30
+ )
31
+ torch_device = device
32
+ torch_dtype = torch.float16
33
 
34
+ # CSS definition
35
+ css = """
36
+ #container{
37
+ margin: 0 auto;
38
+ max-width: 40rem;
39
+ }
40
+ #intro{
41
+ max-width: 100%;
42
+ text-align: center;
43
+ margin: 0 auto;
44
+ }
45
+ """
 
 
 
 
 
 
 
46
 
47
+ def encode_file_to_base64(file_path):
48
+ with open(file_path, "rb") as file:
49
+ encoded = base64.b64encode(file.read()).decode()
50
+ return encoded
51
+
52
+ def create_zip_of_files(files):
53
+ zip_name = "all_files.zip"
54
+ with zipfile.ZipFile(zip_name, 'w') as zipf:
55
+ for file in files:
56
+ zipf.write(file)
57
+ return zip_name
58
+
59
+ def get_zip_download_link(zip_file):
60
+ with open(zip_file, 'rb') as f:
61
+ data = f.read()
62
+ b64 = base64.b64encode(data).decode()
63
+ href = f'<a href="data:application/zip;base64,{b64}" download="{zip_file}">Download All</a>'
64
+ return href
65
+
66
+ def clear_all_images():
67
+ base_dir = os.getcwd()
68
+ img_files = [file for file in os.listdir(base_dir) if file.lower().endswith((".png", ".jpg", ".jpeg"))]
69
+ for file in img_files:
70
+ os.remove(file)
71
+ print('removed:' + file)
72
 
 
73
  def save_all_images(images):
74
  if len(images) == 0:
75
  return None, None
 
92
 
93
  return zip_filename, download_link
94
 
95
+ def save_all_button_click():
96
+ images = [file for file in os.listdir() if file.lower().endswith((".png", ".jpg", ".jpeg"))]
97
+ zip_filename, download_link = save_all_images(images)
98
+ if download_link:
99
+ return gr.HTML(download_link)
100
+
101
+ def clear_all_button_click():
102
+ clear_all_images()
103
+
104
+ print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
105
+ print(f"TORCH_COMPILE: {TORCH_COMPILE}")
106
+ print(f"device: {device}")
107
+
108
+ if mps_available:
109
+ device = torch.device("mps")
110
+ torch_device = "cpu"
111
+ torch_dtype = torch.float32
112
+
113
+ if SAFETY_CHECKER == "True":
114
+ pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7")
115
+ else:
116
+ pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", safety_checker=None)
117
+
118
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
119
+ pipe.to(device=torch_device, dtype=torch_dtype).to(device)
120
+ pipe.unet.to(memory_format=torch.channels_last)
121
+ pipe.set_progress_bar_config(disable=True)
122
+
123
+ if psutil.virtual_memory().total < 64 * 1024**3:
124
+ pipe.enable_attention_slicing()
125
+
126
+ if TORCH_COMPILE:
127
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
128
+ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
129
+ pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
130
+
131
+ pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
132
+ pipe.fuse_lora()
133
+
134
+ def safe_filename(text):
135
+ safe_text = re.sub(r'\W+', '_', text)
136
+ timestamp = datetime.datetime.now().strftime("%Y%m%d")
137
+ return f"{safe_text}_{timestamp}.png"
138
+
139
+ def encode_image(image):
140
+ buffered = BytesIO()
141
+ return base64.b64encode(buffered.getvalue()).decode()
142
+
143
+ def fake_gan():
144
+ base_dir = os.getcwd()
145
+ img_files = [file for file in os.listdir(base_dir) if file.lower().endswith((".png", ".jpg", ".jpeg"))]
146
+ images = [(random.choice(img_files), os.path.splitext(file)[0]) for file in img_files]
147
+ return images
148
+
149
+ def save_prompt_to_history(prompt):
150
+ with open("prompt_history.txt", "a") as f:
151
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
152
+ f.write(f"{timestamp}: {prompt}\n")
153
+
154
+ def predict(prompt, guidance, steps, seed=1231231):
155
+ generator = torch.manual_seed(seed)
156
+ last_time = time.time()
157
+ results = pipe(
158
+ prompt=prompt,
159
+ generator=generator,
160
+ num_inference_steps=steps,
161
+ guidance_scale=guidance,
162
+ width=512,
163
+ height=512,
164
+ output_type="pil",
165
+ )
166
+ print(f"Pipe took {time.time() - last_time} seconds")
167
+
168
+ # Save prompt to history
169
+ save_prompt_to_history(prompt)
170
+
171
+ nsfw_content_detected = (
172
+ results.nsfw_content_detected[0]
173
+ if "nsfw_content_detected" in results
174
+ else False
175
+ )
176
+ if nsfw_content_detected:
177
+ nsfw=gr.Button("🕹️NSFW🎨", scale=1)
178
+
179
+ try:
180
+ central = pytz.timezone('US/Central')
181
+ safe_date_time = datetime.datetime.now().strftime("%Y%m%d")
182
+ replaced_prompt = prompt.replace(" ", "_").replace("\n", "_")
183
+ safe_prompt = "".join(x for x in replaced_prompt if x.isalnum() or x == "_")[:90]
184
+ filename = f"{safe_date_time}_{safe_prompt}.png"
185
+
186
+ if len(results.images) > 0:
187
+ image_path = os.path.join("", filename)
188
+ results.images[0].save(image_path)
189
+ print(f"#Image saved as {image_path}")
190
+ gr.File(image_path)
191
+ gr.Button(link=image_path)
192
+ except:
193
+ return results.images[0]
194
+
195
+ return results.images[0] if len(results.images) > 0 else None
196
+
197
  def read_prompt_history():
198
  if os.path.exists("prompt_history.txt"):
199
  with open("prompt_history.txt", "r") as f:
200
  return f.read()
201
  return "No prompts yet."
202
 
 
203
  with gr.Blocks(css=css) as demo:
204
  with gr.Column(elem_id="container"):
205
+ gr.Markdown(
206
+ """4📝RT🖼️Images - 🕹️ Real Time 🎨 Image Generator Gallery 🌐""",
207
+ elem_id="intro",
208
+ )
209
+ with gr.Row():
210
+ with gr.Row():
211
+ prompt = gr.Textbox(
212
+ placeholder="Insert your prompt here:", scale=5, container=False
213
+ )
214
+ generate_bt = gr.Button("Generate", scale=1)
215
+
216
+ gr.Button("Download", link="/file=all_files.zip")
217
+
218
+ image = gr.Image(type="filepath")
219
+
220
+ with gr.Row(variant="compact"):
221
+ text = gr.Textbox(
222
+ label="Image Sets",
223
+ show_label=False,
224
+ max_lines=1,
225
+ placeholder="Enter your prompt",
226
+ )
227
+ btn = gr.Button("Generate Gallery of Saved Images")
228
+
229
+ gallery = gr.Gallery(
230
+ label="Generated Images", show_label=True, elem_id="gallery"
231
+ )
232
+
233
+ with gr.Row(variant="compact"):
234
+ save_all_button = gr.Button("💾 Save All", scale=1)
235
+ clear_all_button = gr.Button("🗑️ Clear All", scale=1)
236
+
237
+ with gr.Accordion("Advanced options", open=False):
238
+ guidance = gr.Slider(
239
+ label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
240
+ )
241
+ steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
242
+ seed = gr.Slider(
243
+ randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
244
+ )
245
 
 
246
  with gr.Accordion("Prompt History", open=False):
247
  prompt_history = gr.Code(label="Prompt History", language="text", interactive=False)
248
 
249
+ with gr.Accordion("Run with diffusers"):
250
+ gr.Markdown(
251
+ """## Running LCM-LoRAs it with `diffusers`
252
+ ```bash
253
+ pip install diffusers==0.23.0
254
+ ```
255
+
256
+ ```py
257
+ from diffusers import DiffusionPipeline, LCMScheduler
258
+ pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7").to("cuda")
259
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
260
+ pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") #yes, it's a normal LoRA
261
+ results = pipe(
262
+ prompt="ImageEditor",
263
+ num_inference_steps=4,
264
+ guidance_scale=0.0,
265
+ )
266
+ results.images[0]
267
+ ```
268
+ """
269
+ )
270
 
271
+ with gr.Column():
272
+ file_obj = gr.File(label="Input File")
273
+ input = file_obj
274
 
275
+ inputs = [prompt, guidance, steps, seed]
276
  generate_bt.click(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
277
+ btn.click(fake_gan, None, gallery)
278
  prompt.submit(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
279
+ guidance.change(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
280
+ steps.change(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
281
+ seed.change(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
282
+
283
+ def update_prompt_history():
284
+ return read_prompt_history()
285
+
286
  generate_bt.click(fn=update_prompt_history, outputs=prompt_history)
287
  prompt.submit(fn=update_prompt_history, outputs=prompt_history)
288
 
 
289
  save_all_button.click(
290
  fn=lambda: save_all_images([f for f in os.listdir() if f.lower().endswith((".png", ".jpg", ".jpeg"))]),
291
  outputs=[gr.File(), gr.HTML()]
292
  )
293
+ clear_all_button.click(clear_all_button_click)
294
 
295
  demo.queue()
296
  demo.launch(allowed_paths=["/"])