# app.py - uploaded by awacke1
# Last commit: "Update app.py" (dd404a2, verified), 10.1 kB
import base64
import datetime
import gradio as gr
import numpy as np
import os
import pytz
import psutil
import re
import random
import torch
import time
import shutil
import zipfile
from PIL import Image
from io import BytesIO
from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
# Optional dependency: imported only for its import-time side effects; the
# name `ipex` is not referenced anywhere else in the visible code.
try:
    import intel_extension_for_pytorch as ipex  # noqa: F401
except (Exception, SystemExit):
    # Best-effort import: the app must start without the extension.
    # KeyboardInterrupt is deliberately NOT caught (the original bare
    # `except:` swallowed it too).
    # NOTE(review): SystemExit stays in the caught set because the bare
    # `except:` also absorbed an import-time exit() from the optional
    # package - confirm whether that is still required.
    pass

# Runtime switches read from the environment; each is None when unset.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Backend probes. `hasattr` guards torch builds that lack the attribute.
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()

# Preference order: CUDA, then XPU, then CPU. MPS is not selected here;
# `mps_available` is only recorded for later use.
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
torch_dtype = torch.float16

# CSS definition
css = """
#container{
margin: 0 auto;
max-width: 40rem;
}
#intro{
max-width: 100%;
text-align: center;
margin: 0 auto;
}
"""
def encode_file_to_base64(file_path):
    """Return the bytes of the file at *file_path* as a base64 text string."""
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode()
def create_zip_of_files(files):
    """Write every path in *files* into ``all_files.zip`` and return that name.

    The archive is created (or overwritten) relative to the current working
    directory, and each entry keeps zipfile's default archive name for the
    path it was given.
    """
    archive_name = "all_files.zip"
    archive = zipfile.ZipFile(archive_name, 'w')
    with archive:
        for path in files:
            archive.write(path)
    return archive_name
def get_zip_download_link(zip_file):
    """Return an HTML anchor that embeds *zip_file* as a base64 data URI.

    The anchor's ``download`` attribute is the path exactly as passed in.
    """
    with open(zip_file, 'rb') as stream:
        payload = base64.b64encode(stream.read()).decode()
    return f'<a href="data:application/zip;base64,{payload}" download="{zip_file}">Download All</a>'
def clear_all_images():
    """Delete every .png/.jpg/.jpeg file in the current working directory.

    The extension match is case-insensitive, and each removed name is echoed
    to stdout.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    for entry in os.listdir(os.getcwd()):
        if not entry.lower().endswith(image_suffixes):
            continue
        os.remove(entry)
        print('removed:' + entry)
def save_all_images(images):
    """Zip *images* (plus the prompt history, if present) and build a link.

    Returns ``(zip_filename, download_link)``, or ``(None, None)`` when
    *images* is empty. The archive is written to the current working
    directory under a timestamped name; images are stored by base name.
    """
    if len(images) == 0:
        return None, None

    stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    zip_filename = f"images_and_history_{stamp}.zip"
    history_file = "prompt_history.txt"

    with zipfile.ZipFile(zip_filename, 'w') as archive:
        # Image files, flattened to their base names.
        for image_path in images:
            archive.write(image_path, os.path.basename(image_path))
        # Prompt log, only when one has been written.
        if os.path.exists(history_file):
            archive.write(history_file)

    # Embed the finished archive in a data-URI anchor.
    zip_base64 = encode_file_to_base64(zip_filename)
    download_link = f'<a href="data:application/zip;base64,{zip_base64}" download="{zip_filename}">Download All (Images & History)</a>'
    return zip_filename, download_link
def save_all_button_click():
    """Zip all images in the working directory and return a download link.

    Returns a ``gr.HTML`` wrapping the link, or ``None`` when there was
    nothing to archive.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    images = [name for name in os.listdir() if name.lower().endswith(image_suffixes)]
    _, download_link = save_all_images(images)
    if not download_link:
        return None
    return gr.HTML(download_link)
def clear_all_button_click():
    """Click handler for the clear button: delegate to ``clear_all_images``."""
    clear_all_images()
    return None
# Echo the effective configuration at start-up.
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"TORCH_COMPILE: {TORCH_COMPILE}")
print(f"device: {device}")

# When the MPS backend is available, the pipeline is first moved to CPU in
# float32 and only then to the "mps" device (see the chained .to() below).
if mps_available:
    device, torch_device, torch_dtype = torch.device("mps"), "cpu", torch.float32

# The safety checker is kept only when SAFETY_CHECKER is exactly "True".
if SAFETY_CHECKER != "True":
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", safety_checker=None)
else:
    pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7")

pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to(device=torch_device, dtype=torch_dtype).to(device)
pipe.set_progress_bar_config(disable=True)
pipe.unet.to(memory_format=torch.channels_last)

# Enable attention slicing on hosts with less than 64 GiB of RAM.
if psutil.virtual_memory().total < 64 * 1024**3:
    pipe.enable_attention_slicing()

if TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
    # NOTE(review): the flattened source does not show whether this warm-up
    # call belongs to the compile branch; it is assumed to - confirm.
    pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)

# Load the LCM LoRA weights and fuse them into the pipeline.
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipe.fuse_lora()
def safe_filename(text):
    """Turn *text* into a ``<sanitised>_<YYYYMMDD>.png`` file name.

    Every run of non-word characters collapses to a single underscore; the
    date is today's local date.
    """
    date_stamp = datetime.datetime.now().strftime("%Y%m%d")
    sanitised = re.compile(r'\W+').sub('_', text)
    return f"{sanitised}_{date_stamp}.png"
def encode_image(image):
    """Return *image* serialised as PNG and encoded as a base64 text string.

    Args:
        image: An object exposing ``save(fp, format=...)``, such as a PIL
            image.

    The image is written into an in-memory buffer before encoding.
    Previously nothing was ever written to the buffer, so the function
    always returned an empty string.
    """
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()
def fake_gan():
    """List the saved images in the working directory for the gallery.

    Returns:
        A list of ``(file_name, caption)`` tuples, one per .png/.jpg/.jpeg
        file (case-insensitive), sorted by name. The caption is the file
        name without its extension.

    Each caption is paired with its own file. Previously the image was drawn
    at random from the whole list while the caption came from the loop
    variable, so captions did not match their images and files could repeat
    or be left out.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")
    img_files = sorted(
        name for name in os.listdir(os.getcwd())
        if name.lower().endswith(image_suffixes)
    )
    return [(name, os.path.splitext(name)[0]) for name in img_files]
def save_prompt_to_history(prompt):
    """Append a ``<timestamp>: <prompt>`` line to ``prompt_history.txt``.

    The file lives in the current working directory and is created on first
    use; the timestamp is local time at second resolution.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open("prompt_history.txt", "a") as history:
        history.write(f"{stamp}: {prompt}\n")
def _save_generated_image(image, prompt):
    """Save *image* as ``<YYYYMMDD>_<sanitised prompt>.png`` and return the path."""
    date_stamp = datetime.datetime.now().strftime("%Y%m%d")
    underscored = prompt.replace(" ", "_").replace("\n", "_")
    # Keep only alphanumerics and underscores, capped at 90 characters.
    safe_prompt = "".join(ch for ch in underscored if ch.isalnum() or ch == "_")[:90]
    image_path = f"{date_stamp}_{safe_prompt}.png"
    image.save(image_path)
    print(f"#Image saved as {image_path}")
    return image_path


def predict(prompt, guidance, steps, seed=1231231):
    """Generate one 512x512 image for *prompt* and return it.

    Args:
        prompt: Text prompt passed to the pipeline and logged to the history.
        guidance: Forwarded as ``guidance_scale``.
        steps: Forwarded as ``num_inference_steps`` (coerced to ``int``).
        seed: Seed for the torch generator (coerced to ``int``).

    Returns:
        The first generated image, or ``None`` if the pipeline produced none.

    Side effects: appends the prompt to the history file and makes a
    best-effort attempt to save the image to the working directory. A
    failure to save is logged and does not prevent the image being returned.
    """
    # UI sliders may deliver floats; both consumers need integers.
    generator = torch.manual_seed(int(seed))
    started = time.time()
    results = pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        width=512,
        height=512,
        output_type="pil",
    )
    print(f"Pipe took {time.time() - started} seconds")

    # Save prompt to history
    save_prompt_to_history(prompt)

    # The flag list is absent/None when the safety checker is disabled.
    # Building UI components inside a handler has no effect, so the flag is
    # only logged.
    nsfw_flags = getattr(results, "nsfw_content_detected", None)
    if nsfw_flags and nsfw_flags[0]:
        print("NSFW content detected")

    if not results.images:
        return None

    image = results.images[0]
    try:
        _save_generated_image(image, prompt)
    except Exception as exc:
        # Saving is a convenience; still hand the image back to the UI.
        print(f"Could not save generated image: {exc}")
    return image
def read_prompt_history():
    """Return the text of ``prompt_history.txt``, or a placeholder message.

    The placeholder ``"No prompts yet."`` is returned when the file does not
    exist in the current working directory.
    """
    history_path = "prompt_history.txt"
    if not os.path.exists(history_path):
        return "No prompts yet."
    with open(history_path, "r") as history:
        return history.read()
# UI layout and event wiring.
# NOTE(review): the flattened source does not show the original nesting of
# the layout containers; the structure below is a reconstruction - confirm
# against the rendered app.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """4📝RT🖼️Images - 🕹️ Real Time 🎨 Image Generator Gallery 🌐""",
            elem_id="intro",
        )
        with gr.Row():
            with gr.Row():
                prompt = gr.Textbox(
                    placeholder="Insert your prompt here:", scale=5, container=False
                )
                generate_bt = gr.Button("Generate", scale=1)
                gr.Button("Download", link="/file=all_files.zip")
        image = gr.Image(type="filepath")
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Image Sets",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
            )
            btn = gr.Button("Generate Gallery of Saved Images")
        gallery = gr.Gallery(
            label="Generated Images", show_label=True, elem_id="gallery"
        )
        with gr.Row(variant="compact"):
            save_all_button = gr.Button("💾 Save All", scale=1)
            clear_all_button = gr.Button("🗑️ Clear All", scale=1)
        with gr.Accordion("Advanced options", open=False):
            guidance = gr.Slider(
                label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
            )
            steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
            seed = gr.Slider(
                randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
            )
        with gr.Accordion("Prompt History", open=False):
            prompt_history = gr.Textbox(label="Prompt History", lines=10, max_lines=20, interactive=False)
        with gr.Accordion("Run with diffusers"):
            gr.Markdown(
                """## Running LCM-LoRAs it with `diffusers`
```bash
pip install diffusers==0.23.0
```
```py
from diffusers import DiffusionPipeline, LCMScheduler
pipe = DiffusionPipeline.from_pretrained("Lykon/dreamshaper-7").to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") #yes, it's a normal LoRA
results = pipe(
prompt="ImageEditor",
num_inference_steps=4,
guidance_scale=0.0,
)
results.images[0]
```
"""
            )
        with gr.Column():
            # Not wired to any event; the builtin-shadowing alias
            # `input = file_obj` was unused and has been dropped.
            file_obj = gr.File(label="Input File")

        def update_prompt_history():
            """Return the current prompt-history text for the history box."""
            return read_prompt_history()

        inputs = [prompt, guidance, steps, seed]

        # predict returns exactly one value (the image), so each event lists
        # exactly one output; declaring two made Gradio reject the result.
        # The history box is refreshed in a follow-up step so it reflects the
        # prompt predict has just logged.
        for _register_event in (
            generate_bt.click,
            prompt.submit,
            guidance.change,
            steps.change,
            seed.change,
        ):
            _register_event(
                fn=predict, inputs=inputs, outputs=image, show_progress=False
            ).then(fn=update_prompt_history, outputs=prompt_history)

        btn.click(fake_gan, None, gallery)

        # Output slots for the "Save All" archive and its download link.
        saved_zip_file = gr.File()
        saved_zip_link = gr.HTML()
        save_all_button.click(
            fn=lambda: save_all_images([f for f in os.listdir() if f.lower().endswith((".png", ".jpg", ".jpeg"))]),
            outputs=[saved_zip_file, saved_zip_link]
        )
        clear_all_button.click(clear_all_button_click)

demo.queue()
# Only the working directory (where images, archives and the history file
# are written) is exposed; allowed_paths=["/"] let the file route serve any
# file readable by the process.
demo.launch(allowed_paths=[os.getcwd()])