import base64
import datetime
import gradio as gr
import numpy as np
import os
import pytz
import psutil
import re
import random
import torch
import time
import shutil
import zipfile
from PIL import Image
from io import BytesIO
from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny

# Optional Intel GPU (XPU) support; skip silently if the extension is missing.
try:
    import intel_extension_for_pytorch as ipex
except ImportError:
    pass

SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Prefer CUDA, then Intel XPU, then CPU.
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
# Half precision is poorly supported on CPU, so fall back to float32 there.
torch_dtype = torch.float16 if device.type != "cpu" else torch.float32

# CSS definition
css = """
#container{
    margin: 0 auto;
    max-width: 40rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""


def encode_file_to_base64(file_path):
    with open(file_path, "rb") as file:
        encoded = base64.b64encode(file.read()).decode()
    return encoded


def create_zip_of_files(files):
    zip_name = "all_files.zip"
    with zipfile.ZipFile(zip_name, "w") as zipf:
        for file in files:
            zipf.write(file)
    return zip_name


def get_zip_download_link(zip_file):
    with open(zip_file, "rb") as f:
        data = f.read()
    b64 = base64.b64encode(data).decode()
    # Embed the archive as a base64 data URI so the browser can download it directly.
    href = f'<a href="data:application/zip;base64,{b64}" download="{os.path.basename(zip_file)}">Download All</a>'
    return href


def clear_all_images():
    base_dir = os.getcwd()
    img_files = [file for file in os.listdir(base_dir) if file.lower().endswith((".png", ".jpg", ".jpeg"))]
    for file in img_files:
        os.remove(os.path.join(base_dir, file))
        print("removed: " + file)


def save_all_images(images):
    if len(images) == 0:
        return None, None
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    zip_filename = f"images_and_history_{timestamp}.zip"
    with zipfile.ZipFile(zip_filename, "w") as zipf:
        # Add image files
        for file in images:
            zipf.write(file, os.path.basename(file))
        # Add prompt history file
        if os.path.exists("prompt_history.txt"):
            zipf.write("prompt_history.txt")
    # Generate a data-URI download link for the archive
    zip_base64 = encode_file_to_base64(zip_filename)
    download_link = f'<a href="data:application/zip;base64,{zip_base64}" download="{zip_filename}">Download All (Images & History)</a>'
    return zip_filename, download_link


def save_all_button_click():
    images = [file for file in os.listdir() if file.lower().endswith((".png", ".jpg", ".jpeg"))]
    zip_filename, download_link = save_all_images(images)
    if download_link:
        return gr.HTML(download_link)


def clear_all_button_click():
    clear_all_images()


print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"TORCH_COMPILE: {TORCH_COMPILE}")
print(f"device: {device}")

# On Apple Silicon: load on CPU in float32, then move the pipeline to MPS.
if mps_available:
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32

if SAFETY_CHECKER == "True":
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7")
else:
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", safety_checker=None)

# Swap in the LCM scheduler and fuse the LCM-LoRA before any compilation,
# so the compiled UNet already contains the LoRA weights.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipe.fuse_lora()

pipe.to(device=torch_device, dtype=torch_dtype).to(device)
pipe.unet.to(memory_format=torch.channels_last)
pipe.set_progress_bar_config(disable=True)

# Enable attention slicing (lower peak memory) when system RAM is under 64 GB.
if psutil.virtual_memory().total < 64 * 1024**3:
    pipe.enable_attention_slicing()

if TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
    # Warm-up call so compilation happens before the first real request.
    pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)


def safe_filename(text):
    safe_text = re.sub(r"\W+", "_", text)
    timestamp = datetime.datetime.now().strftime("%Y%m%d")
    return f"{safe_text}_{timestamp}.png"


def encode_image(image):
    buffered = BytesIO()
    # Serialize the PIL image into the buffer before encoding it.
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def fake_gan():
    base_dir = os.getcwd()
    img_files = [file for file in os.listdir(base_dir) if file.lower().endswith((".png", ".jpg", ".jpeg"))]
    # Pair each saved image with its own filename (minus extension) as the caption.
    images = [(file, os.path.splitext(file)[0]) for file in img_files]
    return images


def save_prompt_to_history(prompt):
    with open("prompt_history.txt", "a") as f:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        f.write(f"{timestamp}: {prompt}\n")


def predict(prompt, guidance, steps, seed=1231231):
    generator = torch.manual_seed(int(seed))
    last_time = time.time()
    results = pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        width=512,
        height=512,
        output_type="pil",
    )
    print(f"Pipe took {time.time() - last_time} seconds")

    # Save prompt to history
    save_prompt_to_history(prompt)

    # The flag list is None when the safety checker is disabled.
    nsfw_flags = getattr(results, "nsfw_content_detected", None)
    if nsfw_flags and nsfw_flags[0]:
        gr.Warning("🕹️NSFW🎨 content detected")

    img = results.images[0] if len(results.images) > 0 else None
    if img is not None:
        try:
            central = pytz.timezone("US/Central")
            safe_date_time = datetime.datetime.now(central).strftime("%Y%m%d")
            replaced_prompt = prompt.replace(" ", "_").replace("\n", "_")
            safe_prompt = "".join(x for x in replaced_prompt if x.isalnum() or x == "_")[:90]
            filename = f"{safe_date_time}_{safe_prompt}.png"
            img.save(filename)
            print(f"#Image saved as {filename}")
        except OSError as err:
            print(f"Could not save image: {err}")

    # Two values, matching outputs=[image, prompt_history] in the event wiring.
    return img, read_prompt_history()


def read_prompt_history():
    if os.path.exists("prompt_history.txt"):
        with open("prompt_history.txt", "r") as f:
            return f.read()
    return "No prompts yet."
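

# Sketch (hypothetical helper, not called anywhere in this app): the filename
# sanitization inside predict() pulled out into a pure function so it can be
# tested without loading the diffusion pipeline. `build_image_filename` and its
# `when` parameter are illustrative names assumed here, not part of Gradio or
# diffusers. It mirrors predict(): a YYYYMMDD date prefix, spaces and newlines
# turned into "_", only alphanumerics and "_" kept, prompt capped at 90 chars.
import datetime
from typing import Optional


def build_image_filename(prompt: str, when: Optional[datetime.datetime] = None) -> str:
    """Return a filesystem-safe .png name for `prompt` (hypothetical helper)."""
    if when is None:
        when = datetime.datetime.now()
    date_part = when.strftime("%Y%m%d")
    # Same substitutions predict() applies before saving the image.
    replaced = prompt.replace(" ", "_").replace("\n", "_")
    # Drop punctuation and other symbols; keep letters, digits and "_".
    safe_prompt = "".join(ch for ch in replaced if ch.isalnum() or ch == "_")[:90]
    return f"{date_part}_{safe_prompt}.png"


# Usage: build_image_filename("a cat, on mars!", datetime.datetime(2024, 1, 2))
# returns "20240102_a_cat_on_mars.png".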
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """4📝RT🖼️Images - 🕹️ Real Time 🎨 Image Generator Gallery 🌐""",
            elem_id="intro",
        )
        with gr.Row():
            with gr.Row():
                prompt = gr.Textbox(
                    placeholder="Insert your prompt here:", scale=5, container=False
                )
                generate_bt = gr.Button("Generate", scale=1)
                gr.Button("Download", link="/file=all_files.zip")

        image = gr.Image(type="filepath")

        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Image Sets",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            )
            btn = gr.Button("Generate Gallery of Saved Images")
        gallery = gr.Gallery(
            label="Generated Images", show_label=True, elem_id="gallery"
        )
        with gr.Row(variant="compact"):
            save_all_button = gr.Button("💾 Save All", scale=1)
            clear_all_button = gr.Button("🗑️ Clear All", scale=1)
        # Targets for "Save All": the ZIP file and its data-URI download link.
        zip_output = gr.File(label="Saved ZIP")
        download_html = gr.HTML()

        with gr.Accordion("Advanced options", open=False):
            guidance = gr.Slider(
                label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
            )
            steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
            seed = gr.Slider(
                randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
            )
        with gr.Accordion("Prompt History", open=False):
            prompt_history = gr.Textbox(
                label="Prompt History", lines=10, max_lines=20, interactive=False
            )
        with gr.Accordion("Run with diffusers"):
            gr.Markdown(
                """## Running LCM-LoRAs with `diffusers`
                ```bash
                pip install diffusers==0.23.0
                ```

                ```py
                from diffusers import DiffusionPipeline, LCMScheduler
                pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7").to("cuda")
                pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
                pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")  # yes, it's a normal LoRA

                results = pipe(
                    prompt="ImageEditor",
                    num_inference_steps=4,
                    guidance_scale=0.0,
                )
                results.images[0]
                ```
                """
            )
        with gr.Column():
            file_obj = gr.File(label="Input File")

        inputs = [prompt, guidance, steps, seed]
        # predict() returns (image, prompt history), so every trigger refreshes both.
        generate_bt.click(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
        btn.click(fake_gan, None, gallery)
        prompt.submit(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
        guidance.change(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
        steps.change(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)
        seed.change(fn=predict, inputs=inputs, outputs=[image, prompt_history], show_progress=False)

        save_all_button.click(
            fn=lambda: save_all_images(
                [f for f in os.listdir() if f.lower().endswith((".png", ".jpg", ".jpeg"))]
            ),
            outputs=[zip_output, download_html],
        )
        clear_all_button.click(clear_all_button_click)

demo.queue()
# Serve files from the working directory only, not the whole filesystem ("/").
demo.launch(allowed_paths=[os.getcwd()])