import spaces
import gradio as gr
import re
import torch
from PIL import Image
from diffusers import FluxImg2ImgPipeline

# Load FLUX.1-schnell once at startup and keep it resident on the device.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to(device)
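# If VRAM is tight, diffusers' pipe.enable_model_cpu_offload() (requires the
# accelerate package) could replace the .to(device) call above, trading speed
# for memory; this Space keeps the model fully on the GPU.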

def sanitize_prompt(prompt):
    """Drop any character outside a conservative whitelist (letters, digits,
    whitespace, and basic punctuation) before the prompt reaches the model."""
    allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
    return allowed_chars.sub("", prompt)
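# Illustrative: sanitize_prompt("a girl <script>") returns "a girl script".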

def convert_to_fit_size(original_width_and_height, maximum_size=2048):
    """Scale (width, height) down proportionally so neither side exceeds maximum_size."""
    width, height = original_width_and_height
    if width <= maximum_size and height <= maximum_size:
        return width, height
    scaling_factor = maximum_size / max(width, height)
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    return new_width, new_height

def adjust_to_multiple_of_32(width: int, height: int):
    """Round both dimensions down to the nearest multiple of 32, keeping them
    valid for the latent-space pipeline."""
    width = width - (width % 32)
    height = height - (height % 32)
    return width, height
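# Worked example (illustrative): a 3000x2000 upload becomes 2048x1365 after
# convert_to_fit_size, then 2048x1344 after adjust_to_multiple_of_32.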

@spaces.GPU(duration=120)
def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Starting")
    if image is None or not hasattr(image, 'size'):
        raise gr.Error("Please upload an image.")
    prompt = sanitize_prompt(prompt)

    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
        # gr.Number components deliver floats, so cast before seeding and stepping.
        generator = torch.Generator(device).manual_seed(int(seed))
        width, height = convert_to_fit_size(image.size)
        width, height = adjust_to_multiple_of_32(width, height)
        image = image.resize((width, height), Image.LANCZOS)
        # guidance_scale=0 matches the distilled schnell variant.
        output = pipe(
            prompt=prompt, image=image, generator=generator, strength=strength,
            width=width, height=height, guidance_scale=0,
            num_inference_steps=int(num_inference_steps), max_sequence_length=256,
        )
        return output.images[0]

    return process_img2img(image, prompt, strength, seed, inference_step)
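# Note on strength: diffusers img2img skips the early part of the schedule, so
# with num_inference_steps=4 and strength=0.75, roughly int(4 * 0.75) = 3
# denoising steps actually run on the noised input latent.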

def read_file(path: str) -> str:
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
css = """
#demo-container {
border: 4px solid black;
border-radius: 8px;
padding: 20px;
margin: 20px auto;
max-width: 800px;
}
#image_upload, #output-img {
border: 4px solid black;
border-radius: 8px;
width: 256px;
height: 256px;
object-fit: cover;
}
#run_button {
font-weight: bold;
border: 4px solid black;
border-radius: 8px;
padding: 10px 20px;
width: 100%
}
#col-left, #col-right {
max-width: 640px;
margin: 0 auto;
}
.grid-container {
display: flex;
align-items: center;
justify-content: center;
gap: 10px;
}
.text {
font-size: 16px;
}
"""

with gr.Blocks(css=css, elem_id="demo-container") as demo:
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        # gr.HTML(read_file("demo_tools.html"))
        with gr.Row():
            with gr.Column():
                image = gr.Image(width=256, height=256, sources=['upload', 'clipboard'], image_mode='RGB', elem_id="image_upload", type="pil", label="Upload")
                prompt = gr.Textbox(label="Prompt", value="", placeholder="Your prompt", elem_id="prompt")
                btn = gr.Button("Generate", elem_id="run_button", variant="primary")
                with gr.Accordion(label="Advanced Settings", open=False):
                    strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="Strength")
                    seed = gr.Number(value=100, minimum=0, step=1, label="Seed")
                    inference_step = gr.Number(value=4, minimum=1, step=4, label="Inference Steps")
            with gr.Column():
                image_out = gr.Image(width=256, height=256, label="Output", elem_id="output-img", format="jpg")
        gr.HTML(read_file("demo_footer.html"))

    # Both the button click and pressing Enter in the prompt box trigger generation.
    gr.on(
        triggers=[btn.click, prompt.submit],
        fn=process_images,
        inputs=[image, prompt, strength, seed, inference_step],
        outputs=[image_out],
    )

if __name__ == "__main__":
    demo.queue().launch(show_error=True)
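# To try this Space locally (assumes demo_header.html and demo_footer.html sit
# next to app.py, and that gradio, torch, diffusers, and the spaces package
# are installed):
#   python app.py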